//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
#include "llvm/Support/Debug.h"
#include <cassert>
#include <optional>

#define DEBUG_TYPE "memref-to-spirv-pattern"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Returns the offset of the value in `targetBits` representation.
///
/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
/// It's assumed to be non-negative.
///
/// When accessing an element in the array treating it as having elements of
/// `targetBits`, multiple values are loaded at the same time. The method
/// returns the offset of `srcIdx` within the loaded value. For example, if
/// `sourceBits` equals 8 and `targetBits` equals 32, the x-th element is
/// located at (x % 4) * 8, because one i32 holds four 8-bit elements.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
                                  int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  Type type = srcIdx.getType();
  IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
  IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
  auto srcBitsValue =
      builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
  auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
  return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
}
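
// Illustrative sketch of the IR built above (constants may be folded away):
// for an index %i with sourceBits = 8 and targetBits = 32, roughly:
//   %c4     = spirv.Constant 4 : i32   // targetBits / sourceBits
//   %c8     = spirv.Constant 8 : i32   // sourceBits
//   %rem    = spirv.UMod %i, %c4 : i32
//   %offset = spirv.IMul %rem, %c8 : i32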

/// Returns an adjusted spirv::AccessChainOp. Based on the
/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
/// supported. During conversion, if a memref of an unsupported type is used,
/// loads/stores to this memref need to be modified to use a supported higher
/// bitwidth `targetBits` and to extract the required bits. When accessing a
/// 1-D array (spirv.array or spirv.rtarray), the last index is modified to
/// load the bits needed. The extraction of the actual bits needed is handled
/// separately. Note that this only works for a 1-D tensor.
static Value
adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
                             spirv::AccessChainOp op, int sourceBits,
                             int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  const auto loc = op.getLoc();
  Value lastDim = op->getOperand(op.getNumOperands() - 1);
  Type type = lastDim.getType();
  IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
  auto indices = llvm::to_vector<4>(op.getIndices());
  // There are two elements if this is a 1-D tensor.
  assert(indices.size() == 2);
  indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
  Type t = typeConverter.convertType(op.getComponentPtr().getType());
  return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
}
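
// Illustrative sketch: for i8 elements emulated on i32 (sourceBits = 8,
// targetBits = 32), an access chain %base[%c0, %i] is rewritten so that the
// last index addresses the containing 32-bit word, roughly:
//   %c4  = spirv.Constant 4 : i32
//   %idx = spirv.SDiv %i, %c4 : i32
//   %ptr = spirv.AccessChain %base[%c0, %idx] : ...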

/// Casts the given `srcBool` into an integer of `dstType`.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
                            OpBuilder &builder) {
  assert(srcBool.getType().isInteger(1));
  if (dstType.isInteger(1))
    return srcBool;
  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
  return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
                                               zero);
}
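
// Illustrative sketch: casting an i1 %b to i32 produces roughly:
//   %zero = spirv.Constant 0 : i32
//   %one  = spirv.Constant 1 : i32
//   %int  = spirv.Select %b, %one, %zero : i1, i32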

/// Returns the `targetBits`-bit value shifted by the given `offset`, cast to
/// the destination type, and masked.
static Value shiftValue(Location loc, Value value, Value offset, Value mask,
                        OpBuilder &builder) {
  IntegerType dstType = cast<IntegerType>(mask.getType());
  int targetBits = static_cast<int>(dstType.getWidth());
  int valueBits = value.getType().getIntOrFloatBitWidth();
  assert(valueBits <= targetBits);

  if (valueBits == 1) {
    value = castBoolToIntN(loc, value, dstType, builder);
  } else {
    if (valueBits < targetBits) {
      value = builder.create<spirv::UConvertOp>(
          loc, builder.getIntegerType(targetBits), value);
    }

    value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
  }
  return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
                                                         value, offset);
}
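
// Illustrative sketch: placing an i8 %v into a 32-bit word at bit %offset,
// with %mask = 0xFF, produces roughly:
//   %v32 = spirv.UConvert %v : i8 to i32
//   %and = spirv.BitwiseAnd %v32, %mask : i32
//   %res = spirv.ShiftLeftLogical %and, %offset : i32, i32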

/// Returns true if the allocations of memref `type` generated from `allocOp`
/// can be lowered to SPIR-V.
static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
  if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
      return false;
  } else if (isa<memref::AllocaOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Function)
      return false;
  } else {
    return false;
  }

  // Currently only support static shape and int or float or vector of int or
  // float element type.
  if (!type.hasStaticShape())
    return false;

  Type elementType = type.getElementType();
  if (auto vecType = dyn_cast<VectorType>(elementType))
    elementType = vecType.getElementType();
  return elementType.isIntOrFloat();
}
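
// Illustrative examples of what the check above accepts and rejects:
//   memref.alloc()  : memref<16xf32, #spirv.storage_class<Workgroup>> -> ok
//   memref.alloca() : memref<4xi32, #spirv.storage_class<Function>>   -> ok
//   memref.alloc(%d): memref<?xf32, #spirv.storage_class<Workgroup>>
//                     -> rejected (dynamic shape)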

/// Returns the scope to use for atomic operations used when emulating store
/// operations of unsupported integer bitwidths, based on the memref type.
/// Returns std::nullopt on failure.
static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
  if (!sc)
    return std::nullopt;
  switch (sc.getValue()) {
  case spirv::StorageClass::StorageBuffer:
    return spirv::Scope::Device;
  case spirv::StorageClass::Workgroup:
    return spirv::Scope::Workgroup;
  default:
    break;
  }
  return std::nullopt;
}

/// Casts the given `srcInt` into a boolean value.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
  if (srcInt.getType().isInteger(1))
    return srcInt;

  auto zero = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder);
  return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, zero);
}

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires ConversionPattern. DRR generates
// normal RewritePatterns.

namespace {

/// Converts memref.alloca to SPIR-V Function variables.
class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts an allocation operation to SPIR-V. Currently only supports lowering
/// to Workgroup memory when the size is constant. Note that this pattern needs
/// to be applied in a pass that runs at least at spirv.module scope since it
/// will add global variables into the spirv.module.
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
class AtomicRMWOpPattern final
    : public OpConversionPattern<memref::AtomicRMWOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Removes a deallocation if it is a supported allocation. Currently only
/// removes deallocations if the memory space is workgroup memory.
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store on integers.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.memory_space_cast to the appropriate SPIR-V cast operations.
class MemorySpaceCastOpPattern final
    : public OpConversionPattern<memref::MemorySpaceCastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store.
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ReinterpretCastPattern final
    : public OpConversionPattern<memref::ReinterpretCastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class CastPattern final : public OpConversionPattern<memref::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value src = adaptor.getSource();
    Type srcType = src.getType();

    const TypeConverter *converter = getTypeConverter();
    Type dstType = converter->convertType(op.getType());
    if (srcType != dstType)
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "types don't match: " << srcType << " and " << dstType;
      });

    rewriter.replaceOp(op, src);
    return success();
  }
};

/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
class ExtractAlignedPointerAsIndexOpPattern final
    : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = allocaOp.getType();
  if (!isAllocationSupported(allocaOp, allocType))
    return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");

  rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
                                                 spirv::StorageClass::Function,
                                                 /*initializer=*/nullptr);
  return success();
}
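
// Illustrative conversion (a sketch; the exact pointer type depends on the
// type converter):
//   %0 = memref.alloca() : memref<4xf32, #spirv.storage_class<Function>>
// becomes
//   %0 = spirv.Variable : !spirv.ptr<!spirv.array<4 x f32>, Function>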

//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = operation.getType();
  if (!isAllocationSupported(operation, allocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(operation, "type conversion failed");

  // Insert spirv.GlobalVariable for this allocation.
  Operation *parent =
      SymbolTable::getNearestSymbolTable(operation->getParentOp());
  if (!parent)
    return failure();
  Location loc = operation.getLoc();
  spirv::GlobalVariableOp varOp;
  {
    OpBuilder::InsertionGuard guard(rewriter);
    Block &entryBlock = *parent->getRegion(0).begin();
    rewriter.setInsertionPointToStart(&entryBlock);
    auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
    std::string varName =
        std::string("__workgroup_mem__") +
        std::to_string(std::distance(varOps.begin(), varOps.end()));
    varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
                                                     /*initializer=*/nullptr);
  }

  // Get pointer to global variable at the current scope.
  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
  return success();
}
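
// Illustrative conversion (a sketch; the variable name counter and the exact
// pointer type are determined at rewrite time):
//   %0 = memref.alloc() : memref<4xf32, #spirv.storage_class<Workgroup>>
// becomes
//   spirv.GlobalVariable @__workgroup_mem__0
//       : !spirv.ptr<!spirv.array<4 x f32>, Workgroup>
//   %0 = spirv.mlir.addressof @__workgroup_mem__0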

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

LogicalResult
AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (isa<FloatType>(atomicOp.getType()))
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unimplemented floating-point case");

  auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unsupported memref memory space");

  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Type resultType = typeConverter.convertType(atomicOp.getType());
  if (!resultType)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "failed to convert result type");

  auto loc = atomicOp.getLoc();
  Value ptr =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!ptr)
    return failure();

#define ATOMIC_CASE(kind, spirvOp)                                             \
  case arith::AtomicRMWKind::kind:                                             \
    rewriter.replaceOpWithNewOp<spirv::spirvOp>(                               \
        atomicOp, resultType, ptr, *scope,                                     \
        spirv::MemorySemantics::AcquireRelease, adaptor.getValue());           \
    break

  switch (atomicOp.getKind()) {
    ATOMIC_CASE(addi, AtomicIAddOp);
    ATOMIC_CASE(maxs, AtomicSMaxOp);
    ATOMIC_CASE(maxu, AtomicUMaxOp);
    ATOMIC_CASE(mins, AtomicSMinOp);
    ATOMIC_CASE(minu, AtomicUMinOp);
    ATOMIC_CASE(ori, AtomicOrOp);
    ATOMIC_CASE(andi, AtomicAndOp);
  default:
    return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
  }

#undef ATOMIC_CASE

  return success();
}
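
// Illustrative conversion (a sketch; the pointer and scope depend on the
// memref's storage class):
//   %0 = memref.atomic_rmw addi %val, %mem[%i]
//       : (i32, memref<8xi32, #spirv.storage_class<StorageBuffer>>) -> i32
// becomes roughly:
//   %ptr = spirv.AccessChain %mem[%c0, %i] : ...
//   %0 = spirv.AtomicIAdd <Device> <AcquireRelease> %ptr, %val
//       : !spirv.ptr<i32, StorageBuffer>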

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult
DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
  if (!isAllocationSupported(operation, deallocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
  rewriter.eraseOp(operation);
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

struct MemoryRequirements {
  spirv::MemoryAccessAttr memoryAccess;
  IntegerAttr alignment;
};

/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
/// any.
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
  MLIRContext *ctx = accessedPtr.getContext();

  auto memoryAccess = spirv::MemoryAccess::None;
  if (isNontemporal) {
    memoryAccess = spirv::MemoryAccess::Nontemporal;
  }

  auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
  if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
    if (memoryAccess == spirv::MemoryAccess::None) {
      return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
    }
    return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess),
                              IntegerAttr{}};
  }

  // PhysicalStorageBuffers require the `Aligned` attribute.
  auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
  if (!pointeeType)
    return failure();

  // For scalar types, the alignment is determined by their size.
  std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
  if (!sizeInBytes.has_value())
    return failure();

  memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
  auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
  return MemoryRequirements{memAccessAttr, alignment};
}
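
// Illustrative behavior of the helper above (a sketch):
//   !spirv.ptr<f32, StorageBuffer>             -> {no attribute, no alignment}
//   !spirv.ptr<f32, PhysicalStorageBuffer>     -> {Aligned, alignment = 4}
//   nontemporal i64 in PhysicalStorageBuffer   -> {Nontemporal|Aligned, 8}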

/// Given an accessed SPIR-V pointer and the original memref load/store
/// `memAccess` op, calculates the alignment requirements, if any. Takes into
/// account the alignment attributes applied to the load/store op.
template <class LoadOrStoreOp>
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
  static_assert(
      llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
      "Must be called on either memref::LoadOp or memref::StoreOp");

  Operation *memrefAccessOp = loadOrStoreOp.getOperation();
  auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
      spirv::attributeName<spirv::MemoryAccess>());
  auto memrefAlignment =
      memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
  if (memrefMemAccess && memrefAlignment)
    return MemoryRequirements{memrefMemAccess, memrefAlignment};

  return calculateMemoryRequirements(accessedPtr,
                                     loadOrStoreOp.getNontemporal());
}

LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  auto loc = loadOp.getLoc();
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return failure();

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  Type dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = arrayType.getElementType();
    else
      dstType = pointeeType;
  } else {
    // For Vulkan we need to extract element from wrapping struct and array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = arrayType.getElementType();
    else
      dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
  }
  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // If the rewritten load op has the same bit width, use the loaded value
  // directly.
  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          loadOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
                                                   memoryAccess, alignment);
    if (isBool)
      loadVal = castIntNToBool(loc, loadVal, rewriter);
    rewriter.replaceOp(loadOp, loadVal);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Assume that getElementPtr() works linearly. If it's a scalar, the method
  // still returns a linearized access. If the access is not linearized, there
  // will be offset issues.
  assert(accessChainOp.getIndices().size() == 2);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
                                                   memoryAccess, alignment);

  // Shift the bits to the rightmost.
  // ____XXXX________ -> ____________XXXX
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, spvLoadOp.getType(), spvLoadOp, offset);

  // Apply the mask to extract corresponding bits.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  result =
      rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);

  // Apply sign extension to the loaded value unconditionally. The signedness
  // semantics are carried in the operator itself; we rely on other patterns to
  // handle the casting.
  IntegerAttr shiftValueAttr =
      rewriter.getIntegerAttr(dstType, dstBits - srcBits);
  Value shiftValue =
      rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
  result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
                                                            result, shiftValue);
  result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, dstType, result, shiftValue);

  rewriter.replaceOp(loadOp, result);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}
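
// End-to-end sketch of the narrow-load emulation above: loading element %i of
// a memref<8xi8, #spirv.storage_class<StorageBuffer>> is emulated on i32 as,
// roughly:
//   %ptr  = spirv.AccessChain %base[%c0, %i / 4]  // adjusted access chain
//   %word = spirv.Load "StorageBuffer" %ptr : i32
//   %sh   = spirv.ShiftRightArithmetic %word, (%i % 4) * 8 : i32, i32
//   %bits = spirv.BitwiseAnd %sh, 0xFF : i32
// followed by a left shift and an arithmetic right shift by 24 to sign-extend.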

LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return failure();
  Value loadPtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), loadOp.getLoc(), rewriter);

  if (!loadPtr)
    return failure();

  auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
                                             alignment);
  return success();
}

LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp,
                                       "element type is not a signless int");

  auto loc = storeOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to convert element pointer type");

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();

  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(storeOp,
                                       "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  IntegerType dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(pointeeType);
  } else {
    // For Vulkan we need to extract element from wrapping struct and array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(
          cast<spirv::RuntimeArrayType>(structElemType).getElementType());
  }

  if (!dstType)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine destination element type");

  int dstBits = static_cast<int>(dstType.getWidth());
  assert(dstBits % srcBits == 0);

  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          storeOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value storeVal = adaptor.getValue();
    if (isBool)
      storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
                                                memoryAccess, alignment);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Since multiple threads may be writing neighboring elements within the same
  // word, the emulation is done with atomic operations. E.g., if the stored
  // value is i8, rewrite the StoreOp to:
  // 1) load a 32-bit integer
  // 2) clear 8 bits in the loaded value
  // 3) set 8 bits in the loaded value
  // 4) store the 32-bit value back
  //
  // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
  // loaded 32-bit value and the shifted 8-bit store value) as another atomic
  // step.
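  //
  // Illustrative sketch of the atomic sequence emitted below (roughly):
  //   %old = spirv.AtomicAnd <Device> <AcquireRelease> %ptr, %clearBitsMask
  //       : !spirv.ptr<i32, StorageBuffer>
  //   %new = spirv.AtomicOr <Device> <AcquireRelease> %ptr, %storeVal
  //       : !spirv.ptr<i32, StorageBuffer>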
  assert(accessChainOp.getIndices().size() == 2);
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);

  // Create a mask to clear the destination. E.g., if it is the second i8 in
  // i32, 0xFFFF00FF is created.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
      loc, dstType, mask, offset);
  clearBitsMask =
      rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);

  Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");

  rewriter.create<spirv::AtomicAndOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      clearBitsMask);
  rewriter.create<spirv::AtomicOrOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      storeVal);

  // The results of the atomic ops are unused. Since the ops are already
  // inserted, we can simply erase the original StoreOp. Note that
  // rewriter.replaceOp() doesn't work here because it requires the number of
  // results to match.
  rewriter.eraseOp(storeOp);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

//===----------------------------------------------------------------------===//
// MemorySpaceCastOp
//===----------------------------------------------------------------------===//

LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
    memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = addrCastOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  if (!typeConverter.allows(spirv::Capability::Kernel))
    return rewriter.notifyMatchFailure(
        loc, "address space casts require kernel capability");

  auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
  if (!sourceType)
    return rewriter.notifyMatchFailure(
        loc, "SPIR-V lowering requires ranked memref types");
  auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());

  auto sourceStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
  if (!sourceStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
      diag << "source address space " << sourceType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });
  auto resultStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
  if (!resultStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
      diag << "result address space " << resultType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });

  spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
  spirv::StorageClass resultSc = resultStorageClassAttr.getValue();

  Value result = adaptor.getSource();
  Type resultPtrType = typeConverter.convertType(resultType);
  if (!resultPtrType)
    return rewriter.notifyMatchFailure(addrCastOp,
                                       "failed to convert memref type");

  Type genericPtrType = resultPtrType;
  // SPIR-V doesn't have a general address space cast operation. Instead, it has
  // conversions to and from generic pointers. To implement the general case,
  // we use specific-to-generic conversions when the source class is not
  // generic. Then, when the result storage class is not generic, we convert
  // the generic pointer (either the input or an intermediate result) to that
  // class. This also means that we'll need the intermediate generic pointer
  // type if neither the source nor the destination storage class is generic.
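  //
  // Illustrative sketch: casting CrossWorkgroup -> Function goes through the
  // Generic storage class, roughly:
  //   %g = spirv.PtrCastToGeneric %src
  //       : !spirv.ptr<f32, CrossWorkgroup> to !spirv.ptr<f32, Generic>
  //   %r = spirv.GenericCastToPtr %g
  //       : !spirv.ptr<f32, Generic> to !spirv.ptr<f32, Function>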
  if (sourceSc != spirv::StorageClass::Generic &&
      resultSc != spirv::StorageClass::Generic) {
    Type intermediateType =
        MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
                        sourceType.getLayout(),
                        rewriter.getAttr<spirv::StorageClassAttr>(
                            spirv::StorageClass::Generic));
    genericPtrType = typeConverter.convertType(intermediateType);
  }
  if (sourceSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
  }
  if (resultSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
  }
  rewriter.replaceOp(addrCastOp, result);
  return success();
}

LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(
        storeOp, "signless ints are handled by IntStoreOpPattern");
  auto storePtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), storeOp.getLoc(), rewriter);

  if (!storePtr)
    return rewriter.notifyMatchFailure(storeOp, "type conversion failed");

  auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::StoreOp>(
      storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
  return success();
}

LogicalResult ReinterpretCastPattern::matchAndRewrite(
    memref::ReinterpretCastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value src = adaptor.getSource();
  auto srcType = dyn_cast<spirv::PointerType>(src.getType());

  if (!srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid src type " << src.getType();
    });

  const TypeConverter *converter = getTypeConverter();

  auto dstType = converter->convertType<spirv::PointerType>(op.getType());
  if (dstType != srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid dst type " << op.getType();
    });

  OpFoldResult offset =
      getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
          .front();
  if (isZeroInteger(offset)) {
    rewriter.replaceOp(op, src);
    return success();
  }

  Type intType = converter->convertType(rewriter.getIndexType());
  if (!intType)
    return rewriter.notifyMatchFailure(op, "failed to convert index type");

  Location loc = op.getLoc();
  auto offsetValue = [&]() -> Value {
    if (auto val = dyn_cast<Value>(offset))
      return val;

    int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
    Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
    return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
  }();

  rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
      op, src, offsetValue, ValueRange());
  return success();
}
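
// Illustrative conversion (a sketch; only the offset is materialized, the
// sizes and strides are dropped):
//   %0 = memref.reinterpret_cast %src to offset: [4], sizes: [4], strides: [1]
// becomes roughly:
//   %c4 = spirv.Constant 4 : i32
//   %0  = spirv.InBoundsPtrAccessChain %src[%c4] : ...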

//===----------------------------------------------------------------------===//
// ExtractAlignedPointerAsIndexOp
//===----------------------------------------------------------------------===//

LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
    memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Type indexType = typeConverter.getIndexType();
  rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
                                                      adaptor.getSource());
  return success();
}

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

namespace mlir {
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                   RewritePatternSet &patterns) {
  patterns
      .add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
           DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
           MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
           CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
          typeConverter, patterns.getContext());
}
} // namespace mlir
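
// Example wiring from a conversion pass (an illustrative sketch; the actual
// pass setup lives in the MemRefToSPIRVPass implementation):
//   spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(module);
//   std::unique_ptr<ConversionTarget> target =
//       SPIRVConversionTarget::get(targetAttr);
//   SPIRVTypeConverter typeConverter(targetAttr);
//   RewritePatternSet patterns(module.getContext());
//   populateMemRefToSPIRVPatterns(typeConverter, patterns);
//   if (failed(applyPartialConversion(module, *target, std::move(patterns))))
//     signalPassFailure();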