MLIR  15.0.0git
SPIRVToLLVM.cpp
Go to the documentation of this file.
1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
20 #include "mlir/IR/BuiltinOps.h"
21 #include "mlir/IR/PatternMatch.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/FormatVariadic.h"
26 
27 #define DEBUG_TYPE "spirv-to-llvm-pattern"
28 
29 using namespace mlir;
30 
31 //===----------------------------------------------------------------------===//
32 // Utility functions
33 //===----------------------------------------------------------------------===//
34 
35 /// Returns true if the given type is a signed integer or vector type.
36 static bool isSignedIntegerOrVector(Type type) {
37  if (type.isSignedInteger())
38  return true;
39  if (auto vecType = type.dyn_cast<VectorType>())
40  return vecType.getElementType().isSignedInteger();
41  return false;
42 }
43 
44 /// Returns true if the given type is an unsigned integer or vector type
45 static bool isUnsignedIntegerOrVector(Type type) {
46  if (type.isUnsignedInteger())
47  return true;
48  if (auto vecType = type.dyn_cast<VectorType>())
49  return vecType.getElementType().isUnsignedInteger();
50  return false;
51 }
52 
53 /// Returns the bit width of integer, float or vector of float or integer values
54 static unsigned getBitWidth(Type type) {
55  assert((type.isIntOrFloat() || type.isa<VectorType>()) &&
56  "bitwidth is not supported for this type");
57  if (type.isIntOrFloat())
58  return type.getIntOrFloatBitWidth();
59  auto vecType = type.dyn_cast<VectorType>();
60  auto elementType = vecType.getElementType();
61  assert(elementType.isIntOrFloat() &&
62  "only integers and floats have a bitwidth");
63  return elementType.getIntOrFloatBitWidth();
64 }
65 
66 /// Returns the bit width of LLVMType integer or vector.
67 static unsigned getLLVMTypeBitWidth(Type type) {
69  : type)
70  .cast<IntegerType>()
71  .getWidth();
72 }
73 
74 /// Creates `IntegerAttribute` with all bits set for given type
75 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
76  if (auto vecType = type.dyn_cast<VectorType>()) {
77  auto integerType = vecType.getElementType().cast<IntegerType>();
78  return builder.getIntegerAttr(integerType, -1);
79  }
80  auto integerType = type.cast<IntegerType>();
81  return builder.getIntegerAttr(integerType, -1);
82 }
83 
84 /// Creates `llvm.mlir.constant` with all bits set for the given type.
85 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
86  PatternRewriter &rewriter) {
87  if (srcType.isa<VectorType>()) {
88  return rewriter.create<LLVM::ConstantOp>(
89  loc, dstType,
90  SplatElementsAttr::get(srcType.cast<ShapedType>(),
91  minusOneIntegerAttribute(srcType, rewriter)));
92  }
93  return rewriter.create<LLVM::ConstantOp>(
94  loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
95 }
96 
97 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
98 static Value createFPConstant(Location loc, Type srcType, Type dstType,
99  PatternRewriter &rewriter, double value) {
100  if (auto vecType = srcType.dyn_cast<VectorType>()) {
101  auto floatType = vecType.getElementType().cast<FloatType>();
102  return rewriter.create<LLVM::ConstantOp>(
103  loc, dstType,
104  SplatElementsAttr::get(vecType,
105  rewriter.getFloatAttr(floatType, value)));
106  }
107  auto floatType = srcType.cast<FloatType>();
108  return rewriter.create<LLVM::ConstantOp>(
109  loc, dstType, rewriter.getFloatAttr(floatType, value));
110 }
111 
112 /// Utility function for bitfield ops:
113 /// - `BitFieldInsert`
114 /// - `BitFieldSExtract`
115 /// - `BitFieldUExtract`
116 /// Truncates or extends the value. If the bitwidth of the value is the same as
117 /// `llvmType` bitwidth, the value remains unchanged.
119  Type llvmType,
120  PatternRewriter &rewriter) {
121  auto srcType = value.getType();
122  unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
123  unsigned valueBitWidth = LLVM::isCompatibleType(srcType)
124  ? getLLVMTypeBitWidth(srcType)
125  : getBitWidth(srcType);
126 
127  if (valueBitWidth < targetBitWidth)
128  return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
129  // If the bit widths of `Count` and `Offset` are greater than the bit width
130  // of the target type, they are truncated. Truncation is safe since `Count`
131  // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
132  // both values can be expressed in 8 bits.
133  if (valueBitWidth > targetBitWidth)
134  return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
135  return value;
136 }
137 
138 /// Broadcasts the value to vector with `numElements` number of elements.
139 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
140  LLVMTypeConverter &typeConverter,
141  ConversionPatternRewriter &rewriter) {
142  auto vectorType = VectorType::get(numElements, toBroadcast.getType());
143  auto llvmVectorType = typeConverter.convertType(vectorType);
144  auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
145  Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
146  for (unsigned i = 0; i < numElements; ++i) {
147  auto index = rewriter.create<LLVM::ConstantOp>(
148  loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
149  broadcasted = rewriter.create<LLVM::InsertElementOp>(
150  loc, llvmVectorType, broadcasted, toBroadcast, index);
151  }
152  return broadcasted;
153 }
154 
155 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
157  LLVMTypeConverter &typeConverter,
158  ConversionPatternRewriter &rewriter) {
159  if (auto vectorType = srcType.dyn_cast<VectorType>()) {
160  unsigned numElements = vectorType.getNumElements();
161  return broadcast(loc, value, numElements, typeConverter, rewriter);
162  }
163  return value;
164 }
165 
166 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
167 /// `BitFieldUExtract`.
168 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
169 /// a vector type, construct a vector that has:
170 /// - same number of elements as `Base`
171 /// - each element has the type that is the same as the type of `Offset` or
172 /// `Count`
173 /// - each element has the same value as `Offset` or `Count`
174 /// Then cast `Offset` and `Count` if their bit width is different
175 /// from `Base` bit width.
177  Type dstType, LLVMTypeConverter &converter,
178  ConversionPatternRewriter &rewriter) {
179  Value broadcasted =
180  optionallyBroadcast(loc, value, srcType, converter, rewriter);
181  return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
182 }
183 
184 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
185 /// offset to LLVM struct. Otherwise, the conversion is not supported.
186 static Optional<Type>
188  LLVMTypeConverter &converter) {
189  if (type != VulkanLayoutUtils::decorateType(type))
190  return llvm::None;
191 
192  auto elementsVector = llvm::to_vector<8>(
193  llvm::map_range(type.getElementTypes(), [&](Type elementType) {
194  return converter.convertType(elementType);
195  }));
196  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
197  /*isPacked=*/false);
198 }
199 
200 /// Converts SPIR-V struct with no offset to packed LLVM struct.
202  LLVMTypeConverter &converter) {
203  auto elementsVector = llvm::to_vector<8>(
204  llvm::map_range(type.getElementTypes(), [&](Type elementType) {
205  return converter.convertType(elementType);
206  }));
207  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
208  /*isPacked=*/true);
209 }
210 
211 /// Creates LLVM dialect constant with the given value.
213  unsigned value) {
214  return rewriter.create<LLVM::ConstantOp>(
215  loc, IntegerType::get(rewriter.getContext(), 32),
216  rewriter.getIntegerAttr(rewriter.getI32Type(), value));
217 }
218 
219 /// Utility for `spv.Load` and `spv.Store` conversion.
221  ConversionPatternRewriter &rewriter,
222  LLVMTypeConverter &typeConverter,
223  unsigned alignment, bool isVolatile,
224  bool isNonTemporal) {
225  if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
226  auto dstType = typeConverter.convertType(loadOp.getType());
227  if (!dstType)
228  return failure();
229  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
230  loadOp, dstType, spirv::LoadOpAdaptor(operands).ptr(), alignment,
231  isVolatile, isNonTemporal);
232  return success();
233  }
234  auto storeOp = cast<spirv::StoreOp>(op);
235  spirv::StoreOpAdaptor adaptor(operands);
236  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.value(),
237  adaptor.ptr(), alignment,
238  isVolatile, isNonTemporal);
239  return success();
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // Type conversion
244 //===----------------------------------------------------------------------===//
245 
246 /// Converts SPIR-V array type to LLVM array. Natural stride (according to
247 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
248 /// when converting ops that manipulate array types.
250  TypeConverter &converter) {
251  unsigned stride = type.getArrayStride();
252  Type elementType = type.getElementType();
253  auto sizeInBytes = elementType.cast<spirv::SPIRVType>().getSizeInBytes();
254  if (stride != 0 && !(sizeInBytes && *sizeInBytes == stride))
255  return llvm::None;
256 
257  auto llvmElementType = converter.convertType(elementType);
258  unsigned numElements = type.getNumElements();
259  return LLVM::LLVMArrayType::get(llvmElementType, numElements);
260 }
261 
262 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
263 /// modelled at the moment.
265  TypeConverter &converter) {
266  auto pointeeType = converter.convertType(type.getPointeeType());
267  return LLVM::LLVMPointerType::get(pointeeType);
268 }
269 
270 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
271 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
272 /// no modelling of array stride at the moment.
274  TypeConverter &converter) {
275  if (type.getArrayStride() != 0)
276  return llvm::None;
277  auto elementType = converter.convertType(type.getElementType());
278  return LLVM::LLVMArrayType::get(elementType, 0);
279 }
280 
281 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with
282 /// member decorations. Also, only natural offset is supported.
284  LLVMTypeConverter &converter) {
286  type.getMemberDecorations(memberDecorations);
287  if (!memberDecorations.empty())
288  return llvm::None;
289  if (type.hasOffset())
290  return convertStructTypeWithOffset(type, converter);
291  return convertStructTypePacked(type, converter);
292 }
293 
294 //===----------------------------------------------------------------------===//
295 // Operation conversion
296 //===----------------------------------------------------------------------===//
297 
298 namespace {
299 
300 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
301 public:
303 
305  matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
306  ConversionPatternRewriter &rewriter) const override {
307  auto dstType = typeConverter.convertType(op.component_ptr().getType());
308  if (!dstType)
309  return failure();
310  // To use GEP we need to add a first 0 index to go through the pointer.
311  auto indices = llvm::to_vector<4>(adaptor.indices());
312  Type indexType = op.indices().front().getType();
313  auto llvmIndexType = typeConverter.convertType(indexType);
314  if (!llvmIndexType)
315  return failure();
316  Value zero = rewriter.create<LLVM::ConstantOp>(
317  op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
318  indices.insert(indices.begin(), zero);
319  rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, adaptor.base_ptr(),
320  indices);
321  return success();
322  }
323 };
324 
325 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
326 public:
328 
330  matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
331  ConversionPatternRewriter &rewriter) const override {
332  auto dstType = typeConverter.convertType(op.pointer().getType());
333  if (!dstType)
334  return failure();
335  rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, op.variable());
336  return success();
337  }
338 };
339 
340 class BitFieldInsertPattern
341  : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
342 public:
344 
346  matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
347  ConversionPatternRewriter &rewriter) const override {
348  auto srcType = op.getType();
349  auto dstType = typeConverter.convertType(srcType);
350  if (!dstType)
351  return failure();
352  Location loc = op.getLoc();
353 
354  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
355  Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
356  typeConverter, rewriter);
357  Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
358  typeConverter, rewriter);
359 
360  // Create a mask with bits set outside [Offset, Offset + Count - 1].
361  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
362  Value maskShiftedByCount =
363  rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
364  Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
365  maskShiftedByCount, minusOne);
366  Value maskShiftedByCountAndOffset =
367  rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
368  Value mask = rewriter.create<LLVM::XOrOp>(
369  loc, dstType, maskShiftedByCountAndOffset, minusOne);
370 
371  // Extract unchanged bits from the `Base` that are outside of
372  // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
373  Value baseAndMask =
374  rewriter.create<LLVM::AndOp>(loc, dstType, op.base(), mask);
375  Value insertShiftedByOffset =
376  rewriter.create<LLVM::ShlOp>(loc, dstType, op.insert(), offset);
377  rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
378  insertShiftedByOffset);
379  return success();
380  }
381 };
382 
383 /// Converts SPIR-V ConstantOp with scalar or vector type.
384 class ConstantScalarAndVectorPattern
385  : public SPIRVToLLVMConversion<spirv::ConstantOp> {
386 public:
388 
390  matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
391  ConversionPatternRewriter &rewriter) const override {
392  auto srcType = constOp.getType();
393  if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat())
394  return failure();
395 
396  auto dstType = typeConverter.convertType(srcType);
397  if (!dstType)
398  return failure();
399 
400  // SPIR-V constant can be a signed/unsigned integer, which has to be
401  // casted to signless integer when converting to LLVM dialect. Removing the
402  // sign bit may have unexpected behaviour. However, it is better to handle
403  // it case-by-case, given that the purpose of the conversion is not to
404  // cover all possible corner cases.
405  if (isSignedIntegerOrVector(srcType) ||
406  isUnsignedIntegerOrVector(srcType)) {
407  auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
408 
409  if (srcType.isa<VectorType>()) {
410  auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>();
411  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
412  constOp, dstType,
413  dstElementsAttr.mapValues(
414  signlessType, [&](const APInt &value) { return value; }));
415  return success();
416  }
417  auto srcAttr = constOp.value().cast<IntegerAttr>();
418  auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
419  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
420  return success();
421  }
422  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
423  constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
424  return success();
425  }
426 };
427 
428 class BitFieldSExtractPattern
429  : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
430 public:
432 
434  matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
435  ConversionPatternRewriter &rewriter) const override {
436  auto srcType = op.getType();
437  auto dstType = typeConverter.convertType(srcType);
438  if (!dstType)
439  return failure();
440  Location loc = op.getLoc();
441 
442  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
443  Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
444  typeConverter, rewriter);
445  Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
446  typeConverter, rewriter);
447 
448  // Create a constant that holds the size of the `Base`.
449  IntegerType integerType;
450  if (auto vecType = srcType.dyn_cast<VectorType>())
451  integerType = vecType.getElementType().cast<IntegerType>();
452  else
453  integerType = srcType.cast<IntegerType>();
454 
455  auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
456  Value size =
457  srcType.isa<VectorType>()
458  ? rewriter.create<LLVM::ConstantOp>(
459  loc, dstType,
460  SplatElementsAttr::get(srcType.cast<ShapedType>(), baseSize))
461  : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
462 
463  // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
464  // at Offset + Count - 1 is the most significant bit now.
465  Value countPlusOffset =
466  rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
467  Value amountToShiftLeft =
468  rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
469  Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
470  loc, dstType, op.base(), amountToShiftLeft);
471 
472  // Shift the result right, filling the bits with the sign bit.
473  Value amountToShiftRight =
474  rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
475  rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
476  amountToShiftRight);
477  return success();
478  }
479 };
480 
481 class BitFieldUExtractPattern
482  : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
483 public:
485 
487  matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
488  ConversionPatternRewriter &rewriter) const override {
489  auto srcType = op.getType();
490  auto dstType = typeConverter.convertType(srcType);
491  if (!dstType)
492  return failure();
493  Location loc = op.getLoc();
494 
495  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
496  Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
497  typeConverter, rewriter);
498  Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
499  typeConverter, rewriter);
500 
501  // Create a mask with bits set at [0, Count - 1].
502  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
503  Value maskShiftedByCount =
504  rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
505  Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
506  minusOne);
507 
508  // Shift `Base` by `Offset` and apply the mask on it.
509  Value shiftedBase =
510  rewriter.create<LLVM::LShrOp>(loc, dstType, op.base(), offset);
511  rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
512  return success();
513  }
514 };
515 
516 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
517 public:
519 
521  matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
522  ConversionPatternRewriter &rewriter) const override {
523  rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
524  branchOp.getTarget());
525  return success();
526  }
527 };
528 
529 class BranchConditionalConversionPattern
530  : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
531 public:
532  using SPIRVToLLVMConversion<
533  spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
534 
536  matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
537  ConversionPatternRewriter &rewriter) const override {
538  // If branch weights exist, map them to 32-bit integer vector.
539  ElementsAttr branchWeights = nullptr;
540  if (auto weights = op.branch_weights()) {
541  VectorType weightType = VectorType::get(2, rewriter.getI32Type());
542  branchWeights = DenseElementsAttr::get(weightType, weights->getValue());
543  }
544 
545  rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
546  op, op.condition(), op.getTrueBlockArguments(),
547  op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
548  op.getFalseBlock());
549  return success();
550  }
551 };
552 
553 /// Converts `spv.CompositeExtract` to `llvm.extractvalue` if the container type
554 /// is an aggregate type (struct or array). Otherwise, converts to
555 /// `llvm.extractelement` that operates on vectors.
556 class CompositeExtractPattern
557  : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
558 public:
560 
562  matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
563  ConversionPatternRewriter &rewriter) const override {
564  auto dstType = this->typeConverter.convertType(op.getType());
565  if (!dstType)
566  return failure();
567 
568  Type containerType = op.composite().getType();
569  if (containerType.isa<VectorType>()) {
570  Location loc = op.getLoc();
571  IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
572  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
573  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
574  op, dstType, adaptor.composite(), index);
575  return success();
576  }
577  rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
578  op, dstType, adaptor.composite(), op.indices());
579  return success();
580  }
581 };
582 
583 /// Converts `spv.CompositeInsert` to `llvm.insertvalue` if the container type
584 /// is an aggregate type (struct or array). Otherwise, converts to
585 /// `llvm.insertelement` that operates on vectors.
586 class CompositeInsertPattern
587  : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
588 public:
590 
592  matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
593  ConversionPatternRewriter &rewriter) const override {
594  auto dstType = this->typeConverter.convertType(op.getType());
595  if (!dstType)
596  return failure();
597 
598  Type containerType = op.composite().getType();
599  if (containerType.isa<VectorType>()) {
600  Location loc = op.getLoc();
601  IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
602  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
603  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
604  op, dstType, adaptor.composite(), adaptor.object(), index);
605  return success();
606  }
607  rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
608  op, dstType, adaptor.composite(), adaptor.object(), op.indices());
609  return success();
610  }
611 };
612 
613 /// Converts SPIR-V operations that have straightforward LLVM equivalent
614 /// into LLVM dialect operations.
615 template <typename SPIRVOp, typename LLVMOp>
616 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
617 public:
619 
621  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
622  ConversionPatternRewriter &rewriter) const override {
623  auto dstType = this->typeConverter.convertType(operation.getType());
624  if (!dstType)
625  return failure();
626  rewriter.template replaceOpWithNewOp<LLVMOp>(
627  operation, dstType, adaptor.getOperands(), operation->getAttrs());
628  return success();
629  }
630 };
631 
632 /// Converts `spv.ExecutionMode` into a global struct constant that holds
633 /// execution mode information.
634 class ExecutionModePattern
635  : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
636 public:
638 
640  matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
641  ConversionPatternRewriter &rewriter) const override {
642  // First, create the global struct's name that would be associated with
643  // this entry point's execution mode. We set it to be:
644  // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
645  ModuleOp module = op->getParentOfType<ModuleOp>();
646  IntegerAttr executionModeAttr = op.execution_modeAttr();
647  std::string moduleName;
648  if (module.getName().hasValue())
649  moduleName = "_" + module.getName().getValue().str();
650  else
651  moduleName = "";
652  std::string executionModeInfoName =
653  llvm::formatv("__spv_{0}_{1}_execution_mode_info_{2}", moduleName,
654  op.fn().str(), executionModeAttr.getValue());
655 
656  MLIRContext *context = rewriter.getContext();
657  OpBuilder::InsertionGuard guard(rewriter);
658  rewriter.setInsertionPointToStart(module.getBody());
659 
660  // Create a struct type, corresponding to the C struct below.
661  // struct {
662  // int32_t executionMode;
663  // int32_t values[]; // optional values
664  // };
665  auto llvmI32Type = IntegerType::get(context, 32);
666  SmallVector<Type, 2> fields;
667  fields.push_back(llvmI32Type);
668  ArrayAttr values = op.values();
669  if (!values.empty()) {
670  auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
671  fields.push_back(arrayType);
672  }
673  auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
674 
675  // Create `llvm.mlir.global` with initializer region containing one block.
676  auto global = rewriter.create<LLVM::GlobalOp>(
677  UnknownLoc::get(context), structType, /*isConstant=*/true,
678  LLVM::Linkage::External, executionModeInfoName, Attribute(),
679  /*alignment=*/0);
680  Location loc = global.getLoc();
681  Region &region = global.getInitializerRegion();
682  Block *block = rewriter.createBlock(&region);
683 
684  // Initialize the struct and set the execution mode value.
685  rewriter.setInsertionPoint(block, block->begin());
686  Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
687  Value executionMode =
688  rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, executionModeAttr);
689  structValue = rewriter.create<LLVM::InsertValueOp>(
690  loc, structType, structValue, executionMode,
691  ArrayAttr::get(context,
692  {rewriter.getIntegerAttr(rewriter.getI32Type(), 0)}));
693 
694  // Insert extra operands if they exist into execution mode info struct.
695  for (unsigned i = 0, e = values.size(); i < e; ++i) {
696  auto attr = values.getValue()[i];
697  Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
698  structValue = rewriter.create<LLVM::InsertValueOp>(
699  loc, structType, structValue, entry,
700  ArrayAttr::get(context,
701  {rewriter.getIntegerAttr(rewriter.getI32Type(), 1),
702  rewriter.getIntegerAttr(rewriter.getI32Type(), i)}));
703  }
704  rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
705  rewriter.eraseOp(op);
706  return success();
707  }
708 };
709 
710 /// Converts `spv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V global
711 /// returns a pointer, whereas in LLVM dialect the global holds an actual value.
712 /// This difference is handled by `spv.mlir.addressof` and
713 /// `llvm.mlir.addressof`ops that both return a pointer.
714 class GlobalVariablePattern
715  : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
716 public:
718 
720  matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
721  ConversionPatternRewriter &rewriter) const override {
722  // Currently, there is no support of initialization with a constant value in
723  // SPIR-V dialect. Specialization constants are not considered as well.
724  if (op.initializer())
725  return failure();
726 
727  auto srcType = op.type().cast<spirv::PointerType>();
728  auto dstType = typeConverter.convertType(srcType.getPointeeType());
729  if (!dstType)
730  return failure();
731 
732  // Limit conversion to the current invocation only or `StorageBuffer`
733  // required by SPIR-V runner.
734  // This is okay because multiple invocations are not supported yet.
735  auto storageClass = srcType.getStorageClass();
736  switch (storageClass) {
737  case spirv::StorageClass::Input:
738  case spirv::StorageClass::Private:
739  case spirv::StorageClass::Output:
740  case spirv::StorageClass::StorageBuffer:
741  case spirv::StorageClass::UniformConstant:
742  break;
743  default:
744  return failure();
745  }
746 
747  // LLVM dialect spec: "If the global value is a constant, storing into it is
748  // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
749  // storage class that is read-only.
750  bool isConstant = (storageClass == spirv::StorageClass::Input) ||
751  (storageClass == spirv::StorageClass::UniformConstant);
752  // SPIR-V spec: "By default, functions and global variables are private to a
753  // module and cannot be accessed by other modules. However, a module may be
754  // written to export or import functions and global (module scope)
755  // variables.". Therefore, map 'Private' storage class to private linkage,
756  // 'Input' and 'Output' to external linkage.
757  auto linkage = storageClass == spirv::StorageClass::Private
758  ? LLVM::Linkage::Private
759  : LLVM::Linkage::External;
760  auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
761  op, dstType, isConstant, linkage, op.sym_name(), Attribute(),
762  /*alignment=*/0);
763 
764  // Attach location attribute if applicable
765  if (op.locationAttr())
766  newGlobalOp->setAttr(op.locationAttrName(), op.locationAttr());
767 
768  return success();
769  }
770 };
771 
772 /// Converts SPIR-V cast ops that do not have straightforward LLVM
773 /// equivalent in LLVM dialect.
774 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
775 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
776 public:
778 
780  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
781  ConversionPatternRewriter &rewriter) const override {
782 
783  Type fromType = operation.operand().getType();
784  Type toType = operation.getType();
785 
786  auto dstType = this->typeConverter.convertType(toType);
787  if (!dstType)
788  return failure();
789 
790  if (getBitWidth(fromType) < getBitWidth(toType)) {
791  rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType,
792  adaptor.getOperands());
793  return success();
794  }
795  if (getBitWidth(fromType) > getBitWidth(toType)) {
796  rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType,
797  adaptor.getOperands());
798  return success();
799  }
800  return failure();
801  }
802 };
803 
804 class FunctionCallPattern
805  : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
806 public:
808 
810  matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
811  ConversionPatternRewriter &rewriter) const override {
812  if (callOp.getNumResults() == 0) {
813  rewriter.replaceOpWithNewOp<LLVM::CallOp>(
814  callOp, llvm::None, adaptor.getOperands(), callOp->getAttrs());
815  return success();
816  }
817 
818  // Function returns a single result.
819  auto dstType = typeConverter.convertType(callOp.getType(0));
820  rewriter.replaceOpWithNewOp<LLVM::CallOp>(
821  callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
822  return success();
823  }
824 };
825 
826 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
827 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
828 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
829 public:
831 
833  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
834  ConversionPatternRewriter &rewriter) const override {
835 
836  auto dstType = this->typeConverter.convertType(operation.getType());
837  if (!dstType)
838  return failure();
839 
840  rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
841  operation, dstType, predicate, operation.operand1(),
842  operation.operand2());
843  return success();
844  }
845 };
846 
847 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
848 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
849 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
850 public:
852 
854  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
855  ConversionPatternRewriter &rewriter) const override {
856 
857  auto dstType = this->typeConverter.convertType(operation.getType());
858  if (!dstType)
859  return failure();
860 
861  rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
862  operation, dstType, predicate, operation.operand1(),
863  operation.operand2());
864  return success();
865  }
866 };
867 
868 class InverseSqrtPattern
869  : public SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp> {
870 public:
872 
874  matchAndRewrite(spirv::GLSLInverseSqrtOp op, OpAdaptor adaptor,
875  ConversionPatternRewriter &rewriter) const override {
876  auto srcType = op.getType();
877  auto dstType = typeConverter.convertType(srcType);
878  if (!dstType)
879  return failure();
880 
881  Location loc = op.getLoc();
882  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
883  Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.operand());
884  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
885  return success();
886  }
887 };
888 
889 /// Converts `spv.Load` and `spv.Store` to LLVM dialect.
890 template <typename SPIRVOp>
891 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
892 public:
894 
896  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
897  ConversionPatternRewriter &rewriter) const override {
898  if (!op.memory_access()) {
899  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
900  this->typeConverter, /*alignment=*/0,
901  /*isVolatile=*/false,
902  /*isNonTemporal=*/false);
903  }
904  auto memoryAccess = *op.memory_access();
905  switch (memoryAccess) {
906  case spirv::MemoryAccess::Aligned:
907  case spirv::MemoryAccess::None:
908  case spirv::MemoryAccess::Nontemporal:
909  case spirv::MemoryAccess::Volatile: {
910  unsigned alignment =
911  memoryAccess == spirv::MemoryAccess::Aligned ? *op.alignment() : 0;
912  bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
913  bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
914  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
915  this->typeConverter, alignment, isVolatile,
916  isNonTemporal);
917  }
918  default:
919  // There is no support of other memory access attributes.
920  return failure();
921  }
922  }
923 };
924 
925 /// Converts `spv.Not` and `spv.LogicalNot` into LLVM dialect.
926 template <typename SPIRVOp>
927 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
928 public:
930 
932  matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
933  ConversionPatternRewriter &rewriter) const override {
934  auto srcType = notOp.getType();
935  auto dstType = this->typeConverter.convertType(srcType);
936  if (!dstType)
937  return failure();
938 
939  Location loc = notOp.getLoc();
940  IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
941  auto mask = srcType.template isa<VectorType>()
942  ? rewriter.create<LLVM::ConstantOp>(
943  loc, dstType,
945  srcType.template cast<VectorType>(), minusOne))
946  : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
947  rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
948  notOp.operand(), mask);
949  return success();
950  }
951 };
952 
953 /// A template pattern that erases the given `SPIRVOp`.
954 template <typename SPIRVOp>
955 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
956 public:
958 
960  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
961  ConversionPatternRewriter &rewriter) const override {
962  rewriter.eraseOp(op);
963  return success();
964  }
965 };
966 
967 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
968 public:
970 
972  matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
973  ConversionPatternRewriter &rewriter) const override {
974  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
975  ArrayRef<Value>());
976  return success();
977  }
978 };
979 
980 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
981 public:
983 
985  matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
986  ConversionPatternRewriter &rewriter) const override {
987  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
988  adaptor.getOperands());
989  return success();
990  }
991 };
992 
993 /// Converts `spv.mlir.loop` to LLVM dialect. All blocks within selection should
994 /// be reachable for conversion to succeed. The structure of the loop in LLVM
995 /// dialect will be the following:
996 ///
997 /// +------------------------------------+
998 /// | <code before spv.mlir.loop> |
999 /// | llvm.br ^header |
1000 /// +------------------------------------+
1001 /// |
1002 /// +----------------+ |
1003 /// | | |
1004 /// | V V
1005 /// | +------------------------------------+
1006 /// | | ^header: |
1007 /// | | <header code> |
1008 /// | | llvm.cond_br %cond, ^body, ^exit |
1009 /// | +------------------------------------+
1010 /// | |
1011 /// | |----------------------+
1012 /// | | |
1013 /// | V |
1014 /// | +------------------------------------+ |
1015 /// | | ^body: | |
1016 /// | | <body code> | |
1017 /// | | llvm.br ^continue | |
1018 /// | +------------------------------------+ |
1019 /// | | |
1020 /// | V |
1021 /// | +------------------------------------+ |
1022 /// | | ^continue: | |
1023 /// | | <continue code> | |
1024 /// | | llvm.br ^header | |
1025 /// | +------------------------------------+ |
1026 /// | | |
1027 /// +---------------+ +----------------------+
1028 /// |
1029 /// V
1030 /// +------------------------------------+
1031 /// | ^exit: |
1032 /// | llvm.br ^remaining |
1033 /// +------------------------------------+
1034 /// |
1035 /// V
1036 /// +------------------------------------+
1037 /// | ^remaining: |
1038 /// | <code after spv.mlir.loop> |
1039 /// +------------------------------------+
1040 ///
1041 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1042 public:
1044 
1046  matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1047  ConversionPatternRewriter &rewriter) const override {
1048  // There is no support of loop control at the moment.
1049  if (loopOp.loop_control() != spirv::LoopControl::None)
1050  return failure();
1051 
1052  Location loc = loopOp.getLoc();
1053 
1054  // Split the current block after `spv.mlir.loop`. The remaining ops will be
1055  // used in `endBlock`.
1056  Block *currentBlock = rewriter.getBlock();
1057  auto position = Block::iterator(loopOp);
1058  Block *endBlock = rewriter.splitBlock(currentBlock, position);
1059 
1060  // Remove entry block and create a branch in the current block going to the
1061  // header block.
1062  Block *entryBlock = loopOp.getEntryBlock();
1063  assert(entryBlock->getOperations().size() == 1);
1064  auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1065  if (!brOp)
1066  return failure();
1067  Block *headerBlock = loopOp.getHeaderBlock();
1068  rewriter.setInsertionPointToEnd(currentBlock);
1069  rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1070  rewriter.eraseBlock(entryBlock);
1071 
1072  // Branch from merge block to end block.
1073  Block *mergeBlock = loopOp.getMergeBlock();
1074  Operation *terminator = mergeBlock->getTerminator();
1075  ValueRange terminatorOperands = terminator->getOperands();
1076  rewriter.setInsertionPointToEnd(mergeBlock);
1077  rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1078 
1079  rewriter.inlineRegionBefore(loopOp.body(), endBlock);
1080  rewriter.replaceOp(loopOp, endBlock->getArguments());
1081  return success();
1082  }
1083 };
1084 
1085 /// Converts `spv.mlir.selection` with `spv.BranchConditional` in its header
1086 /// block. All blocks within selection should be reachable for conversion to
1087 /// succeed.
1088 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1089 public:
1091 
1093  matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1094  ConversionPatternRewriter &rewriter) const override {
1095  // There is no support for `Flatten` or `DontFlatten` selection control at
1096  // the moment. This are just compiler hints and can be performed during the
1097  // optimization passes.
1098  if (op.selection_control() != spirv::SelectionControl::None)
1099  return failure();
1100 
1101  // `spv.mlir.selection` should have at least two blocks: one selection
1102  // header block and one merge block. If no blocks are present, or control
1103  // flow branches straight to merge block (two blocks are present), the op is
1104  // redundant and it is erased.
1105  if (op.body().getBlocks().size() <= 2) {
1106  rewriter.eraseOp(op);
1107  return success();
1108  }
1109 
1110  Location loc = op.getLoc();
1111 
1112  // Split the current block after `spv.mlir.selection`. The remaining ops
1113  // will be used in `continueBlock`.
1114  auto *currentBlock = rewriter.getInsertionBlock();
1115  rewriter.setInsertionPointAfter(op);
1116  auto position = rewriter.getInsertionPoint();
1117  auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1118 
1119  // Extract conditional branch information from the header block. By SPIR-V
1120  // dialect spec, it should contain `spv.BranchConditional` or `spv.Switch`
1121  // op. Note that `spv.Switch op` is not supported at the moment in the
1122  // SPIR-V dialect. Remove this block when finished.
1123  auto *headerBlock = op.getHeaderBlock();
1124  assert(headerBlock->getOperations().size() == 1);
1125  auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1126  headerBlock->getOperations().front());
1127  if (!condBrOp)
1128  return failure();
1129  rewriter.eraseBlock(headerBlock);
1130 
1131  // Branch from merge block to continue block.
1132  auto *mergeBlock = op.getMergeBlock();
1133  Operation *terminator = mergeBlock->getTerminator();
1134  ValueRange terminatorOperands = terminator->getOperands();
1135  rewriter.setInsertionPointToEnd(mergeBlock);
1136  rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1137 
1138  // Link current block to `true` and `false` blocks within the selection.
1139  Block *trueBlock = condBrOp.getTrueBlock();
1140  Block *falseBlock = condBrOp.getFalseBlock();
1141  rewriter.setInsertionPointToEnd(currentBlock);
1142  rewriter.create<LLVM::CondBrOp>(loc, condBrOp.condition(), trueBlock,
1143  condBrOp.trueTargetOperands(), falseBlock,
1144  condBrOp.falseTargetOperands());
1145 
1146  rewriter.inlineRegionBefore(op.body(), continueBlock);
1147  rewriter.replaceOp(op, continueBlock->getArguments());
1148  return success();
1149  }
1150 };
1151 
1152 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1153 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1154 /// `Shift` is zero or sign extended to match this specification. Cases when
1155 /// `Shift` bit width > `Base` bit width are considered to be illegal.
1156 template <typename SPIRVOp, typename LLVMOp>
1157 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1158 public:
1160 
1162  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
1163  ConversionPatternRewriter &rewriter) const override {
1164 
1165  auto dstType = this->typeConverter.convertType(operation.getType());
1166  if (!dstType)
1167  return failure();
1168 
1169  Type op1Type = operation.operand1().getType();
1170  Type op2Type = operation.operand2().getType();
1171 
1172  if (op1Type == op2Type) {
1173  rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
1174  adaptor.getOperands());
1175  return success();
1176  }
1177 
1178  Location loc = operation.getLoc();
1179  Value extended;
1180  if (isUnsignedIntegerOrVector(op2Type)) {
1181  extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
1182  adaptor.operand2());
1183  } else {
1184  extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
1185  adaptor.operand2());
1186  }
1187  Value result = rewriter.template create<LLVMOp>(
1188  loc, dstType, adaptor.operand1(), extended);
1189  rewriter.replaceOp(operation, result);
1190  return success();
1191  }
1192 };
1193 
1194 class TanPattern : public SPIRVToLLVMConversion<spirv::GLSLTanOp> {
1195 public:
1197 
1199  matchAndRewrite(spirv::GLSLTanOp tanOp, OpAdaptor adaptor,
1200  ConversionPatternRewriter &rewriter) const override {
1201  auto dstType = typeConverter.convertType(tanOp.getType());
1202  if (!dstType)
1203  return failure();
1204 
1205  Location loc = tanOp.getLoc();
1206  Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.operand());
1207  Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.operand());
1208  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1209  return success();
1210  }
1211 };
1212 
1213 /// Convert `spv.Tanh` to
1214 ///
1215 /// exp(2x) - 1
1216 /// -----------
1217 /// exp(2x) + 1
1218 ///
1219 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLSLTanhOp> {
1220 public:
1222 
1224  matchAndRewrite(spirv::GLSLTanhOp tanhOp, OpAdaptor adaptor,
1225  ConversionPatternRewriter &rewriter) const override {
1226  auto srcType = tanhOp.getType();
1227  auto dstType = typeConverter.convertType(srcType);
1228  if (!dstType)
1229  return failure();
1230 
1231  Location loc = tanhOp.getLoc();
1232  Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1233  Value multiplied =
1234  rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.operand());
1235  Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1236  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1237  Value numerator =
1238  rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1239  Value denominator =
1240  rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1241  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1242  denominator);
1243  return success();
1244  }
1245 };
1246 
1247 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1248 public:
1250 
1252  matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1253  ConversionPatternRewriter &rewriter) const override {
1254  auto srcType = varOp.getType();
1255  // Initialization is supported for scalars and vectors only.
1256  auto pointerTo = srcType.cast<spirv::PointerType>().getPointeeType();
1257  auto init = varOp.initializer();
1258  if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>())
1259  return failure();
1260 
1261  auto dstType = typeConverter.convertType(srcType);
1262  if (!dstType)
1263  return failure();
1264 
1265  Location loc = varOp.getLoc();
1266  Value size = createI32ConstantOf(loc, rewriter, 1);
1267  if (!init) {
1268  rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, size);
1269  return success();
1270  }
1271  Value allocated = rewriter.create<LLVM::AllocaOp>(loc, dstType, size);
1272  rewriter.create<LLVM::StoreOp>(loc, adaptor.initializer(), allocated);
1273  rewriter.replaceOp(varOp, allocated);
1274  return success();
1275  }
1276 };
1277 
1278 //===----------------------------------------------------------------------===//
1279 // FuncOp conversion
1280 //===----------------------------------------------------------------------===//
1281 
1282 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1283 public:
1285 
1287  matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1288  ConversionPatternRewriter &rewriter) const override {
1289 
1290  // Convert function signature. At the moment LLVMType converter is enough
1291  // for currently supported types.
1292  auto funcType = funcOp.getFunctionType();
1293  TypeConverter::SignatureConversion signatureConverter(
1294  funcType.getNumInputs());
1295  auto llvmType = typeConverter.convertFunctionSignature(
1296  funcType, /*isVariadic=*/false, signatureConverter);
1297  if (!llvmType)
1298  return failure();
1299 
1300  // Create a new `LLVMFuncOp`
1301  Location loc = funcOp.getLoc();
1302  StringRef name = funcOp.getName();
1303  auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1304 
1305  // Convert SPIR-V Function Control to equivalent LLVM function attribute
1306  MLIRContext *context = funcOp.getContext();
1307  switch (funcOp.function_control()) {
1308 #define DISPATCH(functionControl, llvmAttr) \
1309  case functionControl: \
1310  newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1311  break;
1312 
1313  DISPATCH(spirv::FunctionControl::Inline,
1314  StringAttr::get(context, "alwaysinline"));
1315  DISPATCH(spirv::FunctionControl::DontInline,
1316  StringAttr::get(context, "noinline"));
1317  DISPATCH(spirv::FunctionControl::Pure,
1318  StringAttr::get(context, "readonly"));
1319  DISPATCH(spirv::FunctionControl::Const,
1320  StringAttr::get(context, "readnone"));
1321 
1322 #undef DISPATCH
1323 
1324  // Default: if `spirv::FunctionControl::None`, then no attributes are
1325  // needed.
1326  default:
1327  break;
1328  }
1329 
1330  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1331  newFuncOp.end());
1332  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
1333  &signatureConverter))) {
1334  return failure();
1335  }
1336  rewriter.eraseOp(funcOp);
1337  return success();
1338  }
1339 };
1340 
1341 //===----------------------------------------------------------------------===//
1342 // ModuleOp conversion
1343 //===----------------------------------------------------------------------===//
1344 
1345 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1346 public:
1348 
1350  matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1351  ConversionPatternRewriter &rewriter) const override {
1352 
1353  auto newModuleOp =
1354  rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1355  rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1356 
1357  // Remove the terminator block that was automatically added by builder
1358  rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1359  rewriter.eraseOp(spvModuleOp);
1360  return success();
1361  }
1362 };
1363 
1364 //===----------------------------------------------------------------------===//
1365 // VectorShuffleOp conversion
1366 //===----------------------------------------------------------------------===//
1367 
1368 class VectorShufflePattern
1369  : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
1370 public:
1373  matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1374  ConversionPatternRewriter &rewriter) const override {
1375  Location loc = op.getLoc();
1376  auto components = adaptor.components();
1377  auto vector1 = adaptor.vector1();
1378  auto vector2 = adaptor.vector2();
1379  int vector1Size = vector1.getType().cast<VectorType>().getNumElements();
1380  int vector2Size = vector2.getType().cast<VectorType>().getNumElements();
1381  if (vector1Size == vector2Size) {
1382  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, vector1, vector2,
1383  components);
1384  return success();
1385  }
1386 
1387  auto dstType = typeConverter.convertType(op.getType());
1388  auto scalarType = dstType.cast<VectorType>().getElementType();
1389  auto componentsArray = components.getValue();
1390  auto *context = rewriter.getContext();
1391  auto llvmI32Type = IntegerType::get(context, 32);
1392  Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
1393  for (unsigned i = 0; i < componentsArray.size(); i++) {
1394  if (componentsArray[i].isa<IntegerAttr>())
1395  op.emitError("unable to support non-constant component");
1396 
1397  int indexVal = componentsArray[i].cast<IntegerAttr>().getInt();
1398  if (indexVal == -1)
1399  continue;
1400 
1401  int offsetVal = 0;
1402  Value baseVector = vector1;
1403  if (indexVal >= vector1Size) {
1404  offsetVal = vector1Size;
1405  baseVector = vector2;
1406  }
1407 
1408  Value dstIndex = rewriter.create<LLVM::ConstantOp>(
1409  loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
1410  Value index = rewriter.create<LLVM::ConstantOp>(
1411  loc, llvmI32Type,
1412  rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
1413 
1414  auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
1415  loc, scalarType, baseVector, index);
1416  targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1417  extractOp, dstIndex);
1418  }
1419  rewriter.replaceOp(op, targetOp);
1420  return success();
1421  }
1422 };
1423 } // namespace
1424 
1425 //===----------------------------------------------------------------------===//
1426 // Pattern population
1427 //===----------------------------------------------------------------------===//
1428 
1430  typeConverter.addConversion([&](spirv::ArrayType type) {
1431  return convertArrayType(type, typeConverter);
1432  });
1433  typeConverter.addConversion([&](spirv::PointerType type) {
1434  return convertPointerType(type, typeConverter);
1435  });
1436  typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1437  return convertRuntimeArrayType(type, typeConverter);
1438  });
1439  typeConverter.addConversion([&](spirv::StructType type) {
1440  return convertStructType(type, typeConverter);
1441  });
1442 }
1443 
1445  LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1446  patterns.add<
1447  // Arithmetic ops
1448  DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1449  DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1450  DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1451  DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1452  DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1453  DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1454  DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1455  DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1456  DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1457  DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1458  DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1459  DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1460  DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1461 
1462  // Bitwise ops
1463  BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1464  DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1465  DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1466  DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1467  DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1468  DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1469  NotPattern<spirv::NotOp>,
1470 
1471  // Cast ops
1472  DirectConversionPattern<spirv::BitcastOp, LLVM::BitcastOp>,
1473  DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1474  DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1475  DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1476  DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1477  IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1478  IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1479  IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1480 
1481  // Comparison ops
1482  IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1483  IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1484  FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1485  FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1486  FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1487  FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1488  FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1489  FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1490  FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1491  FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1492  FComparePattern<spirv::FUnordGreaterThanEqualOp,
1493  LLVM::FCmpPredicate::uge>,
1494  FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1495  FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1496  FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1497  IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1498  IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1499  IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1500  IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1501  IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1502  IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1503  IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1504  IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1505 
1506  // Constant op
1507  ConstantScalarAndVectorPattern,
1508 
1509  // Control Flow ops
1510  BranchConversionPattern, BranchConditionalConversionPattern,
1511  FunctionCallPattern, LoopPattern, SelectionPattern,
1512  ErasePattern<spirv::MergeOp>,
1513 
1514  // Entry points and execution mode are handled separately.
1515  ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1516 
1517  // GLSL extended instruction set ops
1518  DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>,
1519  DirectConversionPattern<spirv::GLSLCosOp, LLVM::CosOp>,
1520  DirectConversionPattern<spirv::GLSLExpOp, LLVM::ExpOp>,
1521  DirectConversionPattern<spirv::GLSLFAbsOp, LLVM::FAbsOp>,
1522  DirectConversionPattern<spirv::GLSLFloorOp, LLVM::FFloorOp>,
1523  DirectConversionPattern<spirv::GLSLFMaxOp, LLVM::MaxNumOp>,
1524  DirectConversionPattern<spirv::GLSLFMinOp, LLVM::MinNumOp>,
1525  DirectConversionPattern<spirv::GLSLLogOp, LLVM::LogOp>,
1526  DirectConversionPattern<spirv::GLSLSinOp, LLVM::SinOp>,
1527  DirectConversionPattern<spirv::GLSLSMaxOp, LLVM::SMaxOp>,
1528  DirectConversionPattern<spirv::GLSLSMinOp, LLVM::SMinOp>,
1529  DirectConversionPattern<spirv::GLSLSqrtOp, LLVM::SqrtOp>,
1530  InverseSqrtPattern, TanPattern, TanhPattern,
1531 
1532  // Logical ops
1533  DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1534  DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1535  IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1536  IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1537  NotPattern<spirv::LogicalNotOp>,
1538 
1539  // Memory ops
1540  AccessChainPattern, AddressOfPattern, GlobalVariablePattern,
1541  LoadStorePattern<spirv::LoadOp>, LoadStorePattern<spirv::StoreOp>,
1542  VariablePattern,
1543 
1544  // Miscellaneous ops
1545  CompositeExtractPattern, CompositeInsertPattern,
1546  DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1547  DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1548  VectorShufflePattern,
1549 
1550  // Shift ops
1551  ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1552  ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1553  ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1554 
1555  // Return ops
1556  ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter);
1557 }
1558 
1560  LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1561  patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter);
1562 }
1563 
1565  LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1566  patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
1567 }
1568 
1569 //===----------------------------------------------------------------------===//
1570 // Pre-conversion hooks
1571 //===----------------------------------------------------------------------===//
1572 
1573 /// Hook for descriptor set and binding number encoding.
1574 static constexpr StringRef kBinding = "binding";
1575 static constexpr StringRef kDescriptorSet = "descriptor_set";
1576 void mlir::encodeBindAttribute(ModuleOp module) {
1577  auto spvModules = module.getOps<spirv::ModuleOp>();
1578  for (auto spvModule : spvModules) {
1579  spvModule.walk([&](spirv::GlobalVariableOp op) {
1580  IntegerAttr descriptorSet =
1581  op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1582  IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1583  // For every global variable in the module, get the ones with descriptor
1584  // set and binding numbers.
1585  if (descriptorSet && binding) {
1586  // Encode these numbers into the variable's symbolic name. If the
1587  // SPIR-V module has a name, add it at the beginning.
1588  auto moduleAndName = spvModule.getName().hasValue()
1589  ? spvModule.getName().getValue().str() + "_" +
1590  op.sym_name().str()
1591  : op.sym_name().str();
1592  std::string name =
1593  llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
1594  std::to_string(descriptorSet.getInt()),
1595  std::to_string(binding.getInt()));
1596  auto nameAttr = StringAttr::get(op->getContext(), name);
1597 
1598  // Replace all symbol uses and set the new symbol name. Finally, remove
1599  // descriptor set and binding attributes.
1600  if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule)))
1601  op.emitError("unable to replace all symbol uses for ") << name;
1602  SymbolTable::setSymbolName(op, nameAttr);
1603  op->removeAttr(kDescriptorSet);
1604  op->removeAttr(kBinding);
1605  }
1606  });
1607  }
1608 }
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
iterator begin()
Definition: Block.h:134
U cast() const
Definition: Location.h:67
MLIRContext * getContext() const
Definition: Builders.h:54
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:380
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spv.Load and spv.Store conversion.
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
Definition: LLVMTypes.cpp:414
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:64
operand_range getOperands()
Returns an iterator on the underlying Values.
Definition: Operation.h:302
Type getPointeeType() const
Definition: SPIRVTypes.cpp:395
static LLVMArrayType get(Type elementType, unsigned numElements)
Gets or creates an instance of LLVM dialect array type containing numElements of elementType, in the same context as elementType.
Definition: LLVMTypes.cpp:39
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
Block represents an ordered list of Operations.
Definition: Block.h:29
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:336
OpListType & getOperations()
Definition: Block.h:128
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:60
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:688
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops:
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:87
static Optional< Type > convertStructType(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
Definition: SPIRVToLLVM.cpp:75
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:61
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
Definition: SPIRVToLLVM.cpp:36
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
static constexpr const bool value
static Optional< Type > convertStructTypeWithOffset(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
#define DISPATCH(functionControl, llvmAttr)
Block * getBlock() const
Returns the current block of the builder.
Definition: Builders.h:386
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:350
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:193
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:148
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:380
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
This class provides all of the information necessary to convert a type signature. ...
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:822
OpListType::iterator iterator
Definition: Block.h:131
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:170
U dyn_cast() const
Definition: Types.h:256
static Type convertPointerType(spirv::PointerType type, TypeConverter &converter)
Converts SPIR-V pointer type to LLVM pointer.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
Attributes are known-constant values of operations.
Definition: Attributes.h:24
void populateSPIRVToLLVMModuleConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.
Type getElementType() const
Definition: SPIRVTypes.cpp:62
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
static llvm::Value * getSizeInBytes(llvm::IRBuilderBase &builder, llvm::Value *basePtr)
Computes the size of type in bytes.
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
void populateSPIRVToLLVMConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns that convert from SPIR-V to LLVM.
static LLVMPointerType get(MLIRContext *context, unsigned addressSpace=0)
Gets or creates an instance of LLVM dialect pointer type pointing to an object of pointee type in the...
Definition: LLVMTypes.cpp:188
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
Definition: SPIRVToLLVM.cpp:54
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:49
BlockArgListType getArguments()
Definition: Block.h:76
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
Definition: SPIRVToLLVM.cpp:67
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:230
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:369
static Optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
ElementTypeRange getElementTypes() const
bool isa() const
Definition: Value.h:90
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:286
void addConversion(FnT &&callback)
Register a conversion function.
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:458
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:49
Type getType() const
Return the type of this value.
Definition: Value.h:118
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
Type conversion class.
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:753
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect. ...
Definition: LLVMTypes.cpp:840
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:30
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
static Type convertStructTypePacked(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
SPIR-V struct type.
Definition: SPIRVTypes.h:278
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:91
This class implements a pattern rewriter for use with ConversionPatterns.
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, LLVMTypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
U cast() const
Definition: Value.h:108
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:374
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void encodeBindAttribute(ModuleOp module)
Encodes a global variable's descriptor set and binding into its name if they both exist.
static constexpr StringRef kDescriptorSet
static Optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
Definition: LayoutUtils.cpp:21
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=llvm::None, ArrayRef< Location > locs=llvm::None)
Add a new block with 'argTypes' arguments and set the insertion point to the end of it...
Definition: Builders.cpp:353
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
Definition: SPIRVToLLVM.cpp:85
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:383
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter)
Populates type conversions with additional SPIR-V types.
void populateSPIRVToLLVMFunctionConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
bool isa() const
Definition: Types.h:246
static Value optionallyBroadcast(Location loc, Value value, Type srcType, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:231
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
Definition: SPIRVToLLVM.cpp:45
This class provides an abstraction over the different types of ranges over Values.
Type getVectorElementType(Type type)
Returns the element type of any vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:856
MLIRContext * getContext() const
IntegerType getI32Type()
Definition: Builders.cpp:54
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
Definition: SPIRVToLLVM.cpp:98
An attribute that represents a reference to a dense integer vector or tensor object.
U cast() const
Definition: Types.h:262