//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
#include <cassert>
#include <limits>
#include <optional>

#define DEBUG_TYPE "memref-to-spirv-pattern"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Returns the offset of the value in `targetBits` representation.
///
/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
/// It's assumed to be non-negative.
///
/// When accessing an element in the array while treating it as having elements
/// of `targetBits`, multiple values are loaded at the same time. The method
/// returns the bit offset of `srcIdx` within the loaded value. For example, if
/// `sourceBits` is 8 and `targetBits` is 32, the x-th element is located at
/// (x % 4) * 8, because one i32 holds four 8-bit elements.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
                                  int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  Type type = srcIdx.getType();
  IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
  IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
  auto srcBitsValue =
      builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
  auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
  return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
}
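
// Illustrative sketch: for an i8 source viewed through i32 words,
// getOffsetForBitwidth(loc, %i, /*sourceBits=*/8, /*targetBits=*/32, b)
// emits roughly:
//   %c4  = spirv.Constant 4 : i32   // 32 / 8 elements per word
//   %c8  = spirv.Constant 8 : i32   // bits per source element
//   %m   = spirv.UMod %i, %c4 : i32
//   %off = spirv.IMul %m, %c8 : i32 // (i % 4) * 8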

/// Returns an adjusted spirv::AccessChainOp. Based on the
/// extensions/capabilities, certain integer bitwidths `sourceBits` might not
/// be supported. During conversion, if a memref of an unsupported type is
/// used, loads/stores to this memref need to be modified to use a supported
/// higher bitwidth `targetBits` and to extract the required bits. When
/// accessing a 1-D array (spirv.array or spirv.rtarray), the last index is
/// modified to load the bits needed. The extraction of the actual bits needed
/// is handled separately. Note that this only works for a 1-D tensor.
static Value
adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
                             spirv::AccessChainOp op, int sourceBits,
                             int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  const auto loc = op.getLoc();
  Value lastDim = op->getOperand(op.getNumOperands() - 1);
  Type type = lastDim.getType();
  IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
  auto indices = llvm::to_vector<4>(op.getIndices());
  // There are two elements if this is a 1-D tensor.
  assert(indices.size() == 2);
  indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
  Type t = typeConverter.convertType(op.getComponentPtr().getType());
  return spirv::AccessChainOp::create(builder, loc, t, op.getBasePtr(),
                                      indices);
}
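
// For example (illustrative sketch): when emulating i8 on i32 storage, an
// access chain ending in element index %i into an i8 array is rewritten to
// end in %i / 4 into the corresponding i32 array, roughly:
//   %c4  = spirv.Constant 4 : i32
//   %q   = spirv.SDiv %i, %c4 : i32
//   %ptr = spirv.AccessChain %base[%i0, %q] : ...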

/// Casts the given `srcBool` into an integer of `dstType`.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
                            OpBuilder &builder) {
  assert(srcBool.getType().isInteger(1));
  if (dstType.isInteger(1))
    return srcBool;
  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
  return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
                                               zero);
}

/// Returns the `targetBits`-bit value shifted by the given `offset`, cast to
/// the destination type, and masked.
static Value shiftValue(Location loc, Value value, Value offset, Value mask,
                        OpBuilder &builder) {
  IntegerType dstType = cast<IntegerType>(mask.getType());
  int targetBits = static_cast<int>(dstType.getWidth());
  int valueBits = value.getType().getIntOrFloatBitWidth();
  assert(valueBits <= targetBits);

  if (valueBits == 1) {
    value = castBoolToIntN(loc, value, dstType, builder);
  } else {
    if (valueBits < targetBits) {
      value = spirv::UConvertOp::create(
          builder, loc, builder.getIntegerType(targetBits), value);
    }

    value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
  }
  return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
                                                         value, offset);
}
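
// Worked example (illustrative): placing the i8 value %v into the second byte
// of an i32 word. With mask = 0xFF and offset = 8, shiftValue() emits roughly:
//   %z = spirv.UConvert %v : i8 to i32
//   %a = spirv.BitwiseAnd %z, %mask : i32          // keep the low 8 bits
//   %r = spirv.ShiftLeftLogical %a, %offset : i32, i32  // move into byte 1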

/// Returns true if the allocations of memref `type` generated from `allocOp`
/// can be lowered to SPIR-V.
static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
  if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
      return false;
  } else if (isa<memref::AllocaOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Function)
      return false;
  } else {
    return false;
  }

  // Currently only static shapes with an int, float, or vector-of-int/float
  // element type are supported.
  if (!type.hasStaticShape())
    return false;

  Type elementType = type.getElementType();
  if (auto vecType = dyn_cast<VectorType>(elementType))
    elementType = vecType.getElementType();
  return elementType.isIntOrFloat();
}
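
// For instance (illustrative): a memref.alloc of
// memref<4x4xf32, #spirv.storage_class<Workgroup>> is supported, while an
// allocation with a dynamic shape such as memref<?xf32, ...> or without a
// SPIR-V storage class attribute in its memory space is rejected.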

/// Returns the scope to use for atomic operations used when emulating store
/// operations of unsupported integer bitwidths, based on the memref type.
/// Returns std::nullopt on failure.
static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
  switch (sc.getValue()) {
  case spirv::StorageClass::StorageBuffer:
    return spirv::Scope::Device;
  case spirv::StorageClass::Workgroup:
    return spirv::Scope::Workgroup;
  default:
    break;
  }
  return {};
}

/// Casts the given `srcInt` into a boolean value.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
  if (srcInt.getType().isInteger(1))
    return srcInt;

  auto zero = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder);
  return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, zero);
}

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires ConversionPattern. DRR generates
// normal RewritePattern.

namespace {

/// Converts memref.alloca to SPIR-V Function variables.
class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts an allocation operation to SPIR-V. Currently only supports lowering
/// to Workgroup memory when the size is constant. Note that this pattern needs
/// to be applied in a pass that runs at least at spirv.module scope since it
/// will add global variables into the spirv.module.
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
class AtomicRMWOpPattern final
    : public OpConversionPattern<memref::AtomicRMWOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Removes a deallocation if it is a supported allocation. Currently only
/// removes deallocations if the memory space is workgroup memory.
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Image + spirv.ImageFetch.
class ImageLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store on integers.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.memory_space_cast to the appropriate spirv cast operations.
class MemorySpaceCastOpPattern final
    : public OpConversionPattern<memref::MemorySpaceCastOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store.
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ReinterpretCastPattern final
    : public OpConversionPattern<memref::ReinterpretCastOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class CastPattern final : public OpConversionPattern<memref::CastOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value src = adaptor.getSource();
    Type srcType = src.getType();

    const TypeConverter *converter = getTypeConverter();
    Type dstType = converter->convertType(op.getType());
    if (srcType != dstType)
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "types don't match: " << srcType << " and " << dstType;
      });

    rewriter.replaceOp(op, src);
    return success();
  }
};

/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
class ExtractAlignedPointerAsIndexOpPattern final
    : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = allocaOp.getType();
  if (!isAllocationSupported(allocaOp, allocType))
    return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");

  rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
                                                 spirv::StorageClass::Function,
                                                 /*initializer=*/nullptr);
  return success();
}
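
// Illustrative sketch of the rewrite:
//   %0 = memref.alloca() : memref<4xf32, #spirv.storage_class<Function>>
// becomes, roughly, a function-scope variable:
//   %0 = spirv.Variable : !spirv.ptr<!spirv.array<4 x f32>, Function>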

//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = operation.getType();
  if (!isAllocationSupported(operation, allocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(operation, "type conversion failed");

  // Insert spirv.GlobalVariable for this allocation.
  Operation *parent =
      SymbolTable::getNearestSymbolTable(operation->getParentOp());
  if (!parent)
    return failure();
  Location loc = operation.getLoc();
  spirv::GlobalVariableOp varOp;
  {
    OpBuilder::InsertionGuard guard(rewriter);
    Block &entryBlock = *parent->getRegion(0).begin();
    rewriter.setInsertionPointToStart(&entryBlock);
    auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
    std::string varName =
        std::string("__workgroup_mem__") +
        std::to_string(std::distance(varOps.begin(), varOps.end()));
    varOp = spirv::GlobalVariableOp::create(rewriter, loc, spirvType, varName,
                                            /*initializer=*/nullptr);
  }

  // Get pointer to global variable at the current scope.
  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
  return success();
}
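
// Illustrative sketch: a workgroup allocation such as
//   %m = memref.alloc() : memref<16xi32, #spirv.storage_class<Workgroup>>
// becomes, roughly, a module-level global plus an address-of at the use site:
//   spirv.GlobalVariable @__workgroup_mem__0
//       : !spirv.ptr<!spirv.array<16 x i32>, Workgroup>
//   %m = spirv.mlir.addressof @__workgroup_mem__0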

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

LogicalResult
AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (isa<FloatType>(atomicOp.getType()))
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unimplemented floating-point case");

  auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unsupported memref memory space");

  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Type resultType = typeConverter.convertType(atomicOp.getType());
  if (!resultType)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "failed to convert result type");

  auto loc = atomicOp.getLoc();
  Value ptr =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!ptr)
    return failure();

#define ATOMIC_CASE(kind, spirvOp)                                            \
  case arith::AtomicRMWKind::kind:                                            \
    rewriter.replaceOpWithNewOp<spirv::spirvOp>(                              \
        atomicOp, resultType, ptr, *scope,                                    \
        spirv::MemorySemantics::AcquireRelease, adaptor.getValue());          \
    break

  switch (atomicOp.getKind()) {
    ATOMIC_CASE(addi, AtomicIAddOp);
    ATOMIC_CASE(maxs, AtomicSMaxOp);
    ATOMIC_CASE(maxu, AtomicUMaxOp);
    ATOMIC_CASE(mins, AtomicSMinOp);
    ATOMIC_CASE(minu, AtomicUMinOp);
    ATOMIC_CASE(ori, AtomicOrOp);
    ATOMIC_CASE(andi, AtomicAndOp);
  default:
    return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
  }

#undef ATOMIC_CASE

  return success();
}
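
// For example (illustrative): on a StorageBuffer memref,
//   %old = memref.atomic_rmw addi %v, %m[%i] : (i32, memref<8xi32, ...>) -> i32
// lowers, roughly, to an access chain plus
//   %old = spirv.AtomicIAdd <Device> <AcquireRelease> %ptr, %v
//       : !spirv.ptr<i32, StorageBuffer>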

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult
DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
  if (!isAllocationSupported(operation, deallocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
  rewriter.eraseOp(operation);
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

struct MemoryRequirements {
  spirv::MemoryAccessAttr memoryAccess;
  IntegerAttr alignment;
};

/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
/// any.
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, bool isNontemporal,
                            uint64_t preferredAlignment) {
  if (preferredAlignment >= std::numeric_limits<uint32_t>::max()) {
    return failure();
  }

  MLIRContext *ctx = accessedPtr.getContext();

  auto memoryAccess = spirv::MemoryAccess::None;
  if (isNontemporal) {
    memoryAccess = spirv::MemoryAccess::Nontemporal;
  }

  auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
  bool mayOmitAlignment =
      !preferredAlignment &&
      ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer;
  if (mayOmitAlignment) {
    if (memoryAccess == spirv::MemoryAccess::None) {
      return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
    }
    return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess),
                              IntegerAttr{}};
  }

  // PhysicalStorageBuffers require the `Aligned` attribute.
  // Other storage classes may optionally carry an `Aligned` attribute.
  auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
  if (!pointeeType)
    return failure();

  // For scalar types, the alignment is determined by their size.
  std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
  if (!sizeInBytes.has_value())
    return failure();

  memoryAccess |= spirv::MemoryAccess::Aligned;
  auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
  auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes;
  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), alignmentValue);
  return MemoryRequirements{memAccessAttr, alignment};
}
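
// For instance (illustrative): loading an f32 through a PhysicalStorageBuffer
// pointer with no explicit alignment attribute yields
// { memoryAccess = Aligned, alignment = 4 }, since the scalar's size in bytes
// determines the alignment; a StorageBuffer access with no preferred alignment
// and no nontemporal hint yields empty attributes.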

/// Given an accessed SPIR-V pointer and the original memref load/store
/// `memAccess` op, calculates the alignment requirements, if any. Takes into
/// account the alignment attributes applied to the load/store op.
template <class LoadOrStoreOp>
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
  static_assert(
      llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
      "Must be called on either memref::LoadOp or memref::StoreOp");

  return calculateMemoryRequirements(accessedPtr,
                                     loadOrStoreOp.getNontemporal(),
                                     loadOrStoreOp.getAlignment().value_or(0));
}

LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  auto loc = loadOp.getLoc();
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  auto memorySpaceAttr =
      dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
  if (!memorySpaceAttr)
    return rewriter.notifyMatchFailure(
        loadOp, "missing memory space SPIR-V storage class attribute");

  if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
    return rewriter.notifyMatchFailure(
        loadOp,
        "failed to lower memref in image storage class to storage buffer");

  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return failure();

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  Type dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = arrayType.getElementType();
    else
      dstType = pointeeType;
  } else {
    // For Vulkan we need to extract element from wrapping struct and array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = arrayType.getElementType();
    else
      dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
  }
  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // If the rewritten load op has the same bit width, use the loaded value
  // directly.
  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          loadOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value loadVal = spirv::LoadOp::create(rewriter, loc, accessChain,
                                          memoryAccess, alignment);
    if (isBool)
      loadVal = castIntNToBool(loc, loadVal, rewriter);
    rewriter.replaceOp(loadOp, loadVal);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Assume that getElementPtr() linearizes the access: even for a scalar, the
  // method still returns a linearized access. If the access were not
  // linearized, there would be offset issues.
  assert(accessChainOp.getIndices().size() == 2);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  Value spvLoadOp = spirv::LoadOp::create(rewriter, loc, dstType, adjustedPtr,
                                          memoryAccess, alignment);

  // Shift the bits to the rightmost.
  // ____XXXX________ -> ____________XXXX
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, spvLoadOp.getType(), spvLoadOp, offset);

  // Apply the mask to extract corresponding bits.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  result =
      rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);

  // Apply sign extension on the loaded value unconditionally. The signedness
  // semantics are carried in the operations themselves; we rely on other
  // patterns to handle the casting.
  IntegerAttr shiftValueAttr =
      rewriter.getIntegerAttr(dstType, dstBits - srcBits);
  Value shiftValue =
      rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
  result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
                                                            result, shiftValue);
  result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, dstType, result, shiftValue);

  rewriter.replaceOp(loadOp, result);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}
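
// Putting it together (illustrative sketch): under Vulkan, with i8 emulated
// on i32 storage,
//   %v = memref.load %m[%i] : memref<8xi8, #spirv.storage_class<StorageBuffer>>
// lowers roughly to:
//   %c4  = spirv.Constant 4 : i32
//   %q   = spirv.SDiv %i, %c4 : i32                   // word index
//   %ptr = spirv.AccessChain %base[%c0, %q] : ...
//   %w   = spirv.Load "StorageBuffer" %ptr : i32
//   %m4  = spirv.UMod %i, %c4 : i32
//   %off = spirv.IMul %m4, %c8 : i32                  // bit offset in the word
//   %s   = spirv.ShiftRightArithmetic %w, %off : i32, i32
//   %b   = spirv.BitwiseAnd %s, %c255 : i32           // extract the byte
//   %sl  = spirv.ShiftLeftLogical %b, %c24 : i32, i32
//   %v   = spirv.ShiftRightArithmetic %sl, %c24 : i32, i32  // sign-extend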

LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return failure();

  auto memorySpaceAttr =
      dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
  if (!memorySpaceAttr)
    return rewriter.notifyMatchFailure(
        loadOp, "missing memory space SPIR-V storage class attribute");

  if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
    return rewriter.notifyMatchFailure(
        loadOp,
        "failed to lower memref in image storage class to storage buffer");

  Value loadPtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), loadOp.getLoc(), rewriter);

  if (!loadPtr)
    return failure();

  auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
                                             alignment);
  return success();
}

template <typename OpAdaptor>
static FailureOr<SmallVector<Value>>
extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) {
  // At present we only support linear "tiling" as specified in Vulkan. This
  // means that texels are assumed to be laid out in memory in a row-major
  // order, which allows us to support any memref layout that is a permutation
  // of the dimensions. Future work will pass an optional image layout to the
  // rewrite pattern so that we can support optimized target-specific tilings.
  SmallVector<Value> indices = adaptor.getIndices();
  AffineMap map = loadOp.getMemRefType().getLayout().getAffineMap();
  if (!map.isPermutation())
    return rewriter.notifyMatchFailure(
        loadOp,
        "Cannot lower memrefs with memory layout which is not a permutation");

  // The memref's layout determines the dimension ordering, so we need to
  // follow the map to get the ordering of the dimensions/indices.
  const unsigned dimCount = map.getNumDims();
  SmallVector<Value, 3> coords(dimCount);
  for (unsigned dim = 0; dim < dimCount; ++dim)
    coords[map.getDimPosition(dim)] = indices[dim];

  // We need to reverse the coordinates because the memref layout is
  // slowest-to-fastest moving while the vector coordinates for the image op
  // are fastest-to-slowest moving.
  return llvm::to_vector(llvm::reverse(coords));
}
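
// For example (illustrative): with layout map (d0, d1) -> (d1, d0) and indices
// [%i, %j], the permutation places %i at position 1 and %j at position 0,
// giving coords = [%j, %i]; reversing yields [%i, %j] as the
// fastest-to-slowest image coordinates.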

LogicalResult
ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());

  auto memorySpaceAttr =
      dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
  if (!memorySpaceAttr)
    return rewriter.notifyMatchFailure(
        loadOp, "missing memory space SPIR-V storage class attribute");

  if (memorySpaceAttr.getValue() != spirv::StorageClass::Image)
    return rewriter.notifyMatchFailure(
        loadOp, "failed to lower memref in non-image storage class to image");

  Value loadPtr = adaptor.getMemref();
  auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  const auto [memoryAccess, alignment] = *memoryRequirements;

  if (!loadOp.getMemRefType().hasRank())
    return rewriter.notifyMatchFailure(
        loadOp, "cannot lower unranked memrefs to SPIR-V images");

  // We currently only support lowering of scalar memref elements to texels in
  // the R[16|32][f|i|ui] formats. Future work will enable lowering of vector
  // elements to texels in richer formats.
  if (!isa<spirv::ScalarType>(loadOp.getMemRefType().getElementType()))
    return rewriter.notifyMatchFailure(
        loadOp,
        "cannot lower memrefs whose element type is not a SPIR-V scalar type "
        "to SPIR-V images");

  // We currently only support sampled images since OpImageFetch does not work
  // for plain images; the OpImageRead instruction would need to be
  // materialized instead, or texels would need to be accessed via atomics
  // through a texel pointer. Future work will generalize support to plain
  // images.
  auto convertedPointeeType = cast<spirv::PointerType>(
      getTypeConverter()->convertType(loadOp.getMemRefType()));
  if (!isa<spirv::SampledImageType>(convertedPointeeType.getPointeeType()))
    return rewriter.notifyMatchFailure(loadOp,
                                       "cannot lower memrefs which do not "
                                       "convert to SPIR-V sampled images");

  // Materialize the lowering.
  Location loc = loadOp->getLoc();
  auto imageLoadOp =
      spirv::LoadOp::create(rewriter, loc, loadPtr, memoryAccess, alignment);
  // Extract the image from the sampled image.
  auto imageOp = spirv::ImageOp::create(rewriter, loc, imageLoadOp);

  // Build a vector of coordinates or just a scalar index if we have a 1D image.
  Value coords;
  if (memrefType.getRank() == 1) {
    coords = adaptor.getIndices()[0];
  } else {
    FailureOr<SmallVector<Value>> maybeCoords =
        extractLoadCoordsForComposite(loadOp, adaptor, rewriter);
    if (failed(maybeCoords))
      return failure();
    auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()},
                                           adaptor.getIndices().getType()[0]);
    coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType,
                                                 maybeCoords.value());
  }

  // Fetch the value out of the image.
  auto resultVectorType = VectorType::get({4}, loadOp.getType());
  auto fetchOp = spirv::ImageFetchOp::create(
      rewriter, loc, resultVectorType, imageOp, coords,
      mlir::spirv::ImageOperandsAttr{}, ValueRange{});

  // Note that because OpImageFetch returns a rank-4 vector, we need to extract
  // the element corresponding to the load, which, since we only support the
  // R[16|32][f|i|ui] formats, will always be the R (red) 0th vector element.
  auto compositeExtractOp =
      spirv::CompositeExtractOp::create(rewriter, loc, fetchOp, 0);

  rewriter.replaceOp(loadOp, compositeExtractOp);
  return success();
}
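
// Illustrative sketch: a 2-D image load such as
//   %v = memref.load %img[%i, %j]
//       : memref<16x16xf32, #spirv.storage_class<Image>>
// becomes, roughly:
//   %si = spirv.Load "Image" %imgPtr : !spirv.sampled_image<...>
//   %im = spirv.Image %si
//   %xy = spirv.CompositeConstruct %j, %i : (i32, i32) -> vector<2xi32>
//   %t  = spirv.ImageFetch %im, %xy : ... -> vector<4xf32>
//   %v  = spirv.CompositeExtract %t[0 : i32] : vector<4xf32>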

LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp,
                                       "element type is not a signless int");

  auto loc = storeOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to convert element pointer type");

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();

  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(storeOp,
                                       "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  IntegerType dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(pointeeType);
  } else {
    // For Vulkan we need to extract element from wrapping struct and array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(
          cast<spirv::RuntimeArrayType>(structElemType).getElementType());
  }

  if (!dstType)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine destination element type");

  int dstBits = static_cast<int>(dstType.getWidth());
  assert(dstBits % srcBits == 0);

  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          storeOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value storeVal = adaptor.getValue();
    if (isBool)
      storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
                                                memoryAccess, alignment);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Since multiple threads may be accessing the memory concurrently, the
  // emulation is done with atomic operations. E.g., if the stored value is i8,
  // rewrite the StoreOp to:
  // 1) load a 32-bit integer
  // 2) clear 8 bits in the loaded value
  // 3) set 8 bits in the loaded value
  // 4) store the 32-bit value back
  //
  // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
  // loaded 32-bit value and the shifted 8-bit store value) as another atomic
  // step.
  assert(accessChainOp.getIndices().size() == 2);
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);

  // Create a mask to clear the destination. E.g., if it is the second i8 in
  // i32, 0xFFFF00FF is created.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
      loc, dstType, mask, offset);
  clearBitsMask =
      rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);

  Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");

  Value result = spirv::AtomicAndOp::create(
      rewriter, loc, dstType, adjustedPtr, *scope,
      spirv::MemorySemantics::AcquireRelease, clearBitsMask);
  result = spirv::AtomicOrOp::create(
      rewriter, loc, dstType, adjustedPtr, *scope,
      spirv::MemorySemantics::AcquireRelease, storeVal);

  // The atomic ops above have already been inserted, so we can just erase the
  // original StoreOp. Note that rewriter.replaceOp() doesn't work here because
  // it requires the replacement to have the same number of results.
  rewriter.eraseOp(storeOp);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}
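
// Illustrative sketch: storing the i8 value %v into byte 1 of an i32 word
// becomes, roughly:
//   %not  = spirv.Not %maskShifted : i32       // e.g. 0xFFFF00FF
//   %old  = spirv.AtomicAnd <Device> <AcquireRelease> %ptr, %not
//   %new  = spirv.ShiftLeftLogical %vMasked, %off : i32, i32
//   %old2 = spirv.AtomicOr <Device> <AcquireRelease> %ptr, %new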

//===----------------------------------------------------------------------===//
// MemorySpaceCastOp
//===----------------------------------------------------------------------===//

LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
    memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = addrCastOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  if (!typeConverter.allows(spirv::Capability::Kernel))
    return rewriter.notifyMatchFailure(
        loc, "address space casts require kernel capability");

  auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
  if (!sourceType)
    return rewriter.notifyMatchFailure(
        loc, "SPIR-V lowering requires ranked memref types");
  auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());

  auto sourceStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
  if (!sourceStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
      diag << "source address space " << sourceType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });
  auto resultStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
  if (!resultStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
      diag << "result address space " << resultType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });

  spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
  spirv::StorageClass resultSc = resultStorageClassAttr.getValue();

  Value result = adaptor.getSource();
  Type resultPtrType = typeConverter.convertType(resultType);
  if (!resultPtrType)
    return rewriter.notifyMatchFailure(addrCastOp,
                                       "failed to convert memref type");

  Type genericPtrType = resultPtrType;
  // SPIR-V doesn't have a general address space cast operation. Instead, it has
  // conversions to and from generic pointers. To implement the general case,
  // we use specific-to-generic conversions when the source class is not
  // generic. Then, when the result storage class is not generic, we convert
  // the generic pointer (either the input or an intermediate result) to that
  // class. This also means that we'll need the intermediate generic pointer
  // type if neither the source nor the destination has it.
  if (sourceSc != spirv::StorageClass::Generic &&
      resultSc != spirv::StorageClass::Generic) {
    Type intermediateType =
        MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
                        sourceType.getLayout(),
                        rewriter.getAttr<spirv::StorageClassAttr>(
                            spirv::StorageClass::Generic));
    genericPtrType = typeConverter.convertType(intermediateType);
  }
  if (sourceSc != spirv::StorageClass::Generic) {
    result = spirv::PtrCastToGenericOp::create(rewriter, loc, genericPtrType,
                                               result);
  }
  if (resultSc != spirv::StorageClass::Generic) {
    result =
        spirv::GenericCastToPtrOp::create(rewriter, loc, resultPtrType, result);
  }
  rewriter.replaceOp(addrCastOp, result);
  return success();
}
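
// For example (illustrative): casting from CrossWorkgroup to Function memory
// goes through the Generic class, roughly:
//   %g = spirv.PtrCastToGeneric %src
//       : !spirv.ptr<f32, CrossWorkgroup> to !spirv.ptr<f32, Generic>
//   %r = spirv.GenericCastToPtr %g
//       : !spirv.ptr<f32, Generic> to !spirv.ptr<f32, Function>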

LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp, "signless int");
  auto storePtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), storeOp.getLoc(), rewriter);

  if (!storePtr)
    return rewriter.notifyMatchFailure(storeOp, "type conversion failed");

  auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::StoreOp>(
      storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
  return success();
}

LogicalResult ReinterpretCastPattern::matchAndRewrite(
    memref::ReinterpretCastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value src = adaptor.getSource();
  auto srcType = dyn_cast<spirv::PointerType>(src.getType());

  if (!srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid src type " << src.getType();
    });

  const TypeConverter *converter = getTypeConverter();

  auto dstType = converter->convertType<spirv::PointerType>(op.getType());
  if (dstType != srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid dst type " << op.getType();
    });

  OpFoldResult offset =
      getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
          .front();
  if (isZeroInteger(offset)) {
    rewriter.replaceOp(op, src);
    return success();
  }

  Type intType = converter->convertType(rewriter.getIndexType());
  if (!intType)
    return rewriter.notifyMatchFailure(op, "failed to convert index type");

  Location loc = op.getLoc();
  auto offsetValue = [&]() -> Value {
    if (auto val = dyn_cast<Value>(offset))
      return val;

    int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
    Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
    return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
  }();

  rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
      op, src, offsetValue, ValueRange());
  return success();
}
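
// Illustrative sketch: a reinterpret_cast with a nonzero static offset, e.g.
//   %r = memref.reinterpret_cast %m to offset: [8], sizes: [4], strides: [1]
//       : memref<...> to memref<...>
// becomes, roughly:
//   %c8 = spirv.Constant 8 : i32
//   %r  = spirv.InBoundsPtrAccessChain %m[%c8] : ...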

//===----------------------------------------------------------------------===//
// ExtractAlignedPointerAsIndexOp
//===----------------------------------------------------------------------===//

LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
    memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Type indexType = typeConverter.getIndexType();
  rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
                                                      adaptor.getSource());
  return success();
}

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

namespace mlir {
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                   RewritePatternSet &patterns) {
  patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
               DeallocOpPattern, IntLoadOpPattern, ImageLoadOpPattern,
               IntStoreOpPattern, LoadOpPattern, MemorySpaceCastOpPattern,
               StoreOpPattern, ReinterpretCastPattern, CastPattern,
               ExtractAlignedPointerAsIndexOpPattern>(typeConverter,
                                                      patterns.getContext());
}
} // namespace mlir