//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
#include "llvm/Support/Debug.h"
#include <cassert>
#include <optional>

#define DEBUG_TYPE "memref-to-spirv-pattern"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Returns the offset of the value in `targetBits` representation.
///
/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
/// It's assumed to be non-negative.
///
/// When accessing an element in the array as if it had elements of
/// `targetBits`, multiple source values are loaded at the same time. This
/// method returns the bit offset at which `srcIdx` is located within the
/// loaded value. For example, if `sourceBits` is 8 and `targetBits` is 32, the
/// x-th element is located at bit (x % 4) * 8, because one i32 holds four
/// 8-bit elements.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
                                  int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  Type type = srcIdx.getType();
  IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
  IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
  auto srcBitsValue =
      builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
  auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
  return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
}
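
// A worked example of the computation above (illustrative only): emulating i8
// elements over i32 words, i.e. sourceBits = 8 and targetBits = 32, with
// srcIdx = 5:
//   elementsPerWord = 32 / 8       = 4
//   offset          = (5 % 4) * 8  = 8
// so element 5 occupies bits [8, 16) of its containing i32 word.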

/// Returns an adjusted spirv::AccessChainOp. Based on the
/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
/// supported. During conversion, if a memref of an unsupported type is used,
/// loads/stores to this memref need to be modified to use a supported higher
/// bitwidth `targetBits` and to extract the required bits. When accessing a
/// 1-D array (spirv.array or spirv.rtarray), the last index is modified to
/// load the bits needed. The extraction of the actual bits needed is handled
/// separately. Note that this only works for a 1-D tensor.
static Value
adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
                             spirv::AccessChainOp op, int sourceBits,
                             int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  const auto loc = op.getLoc();
  Value lastDim = op->getOperand(op.getNumOperands() - 1);
  Type type = lastDim.getType();
  IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
  auto indices = llvm::to_vector<4>(op.getIndices());
  // There are two elements if this is a 1-D tensor.
  assert(indices.size() == 2);
  indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
  Type t = typeConverter.convertType(op.getComponentPtr().getType());
  return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
}
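
// A sketch of the adjustment above (illustrative, names abbreviated): when
// emulating i8 elements over i32 words, an access chain such as
//   %p = spirv.AccessChain %base[%c0, %idx]      // points at an i8 element
// is rebuilt as
//   %p = spirv.AccessChain %base[%c0, %idx / 4]  // points at an i32 word
// with the remaining bit offset (%idx % 4) * 8 computed separately by
// getOffsetForBitwidth().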

/// Casts the given `srcBool` into an integer of `dstType`.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
                            OpBuilder &builder) {
  assert(srcBool.getType().isInteger(1));
  if (dstType.isInteger(1))
    return srcBool;
  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
  return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
                                               zero);
}

/// Returns the `targetBits`-bit value shifted by the given `offset`, cast to
/// the destination type, and masked.
static Value shiftValue(Location loc, Value value, Value offset, Value mask,
                        OpBuilder &builder) {
  IntegerType dstType = cast<IntegerType>(mask.getType());
  int targetBits = static_cast<int>(dstType.getWidth());
  int valueBits = value.getType().getIntOrFloatBitWidth();
  assert(valueBits <= targetBits);

  if (valueBits == 1) {
    value = castBoolToIntN(loc, value, dstType, builder);
  } else {
    if (valueBits < targetBits) {
      value = builder.create<spirv::UConvertOp>(
          loc, builder.getIntegerType(targetBits), value);
    }

    value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
  }
  return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
                                                         value, offset);
}
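
// Example (illustrative): storing an i8 value into the second byte of an i32
// word, with mask = 0xFF and offset = 8:
//   masked  = value & 0xFF   // keep only the low 8 bits
//   shifted = masked << 8    // move them to bits [8, 16)
// The caller is expected to clear bits [8, 16) of the destination word first.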

/// Returns true if the allocations of memref `type` generated from `allocOp`
/// can be lowered to SPIR-V.
static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
  if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
      return false;
  } else if (isa<memref::AllocaOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Function)
      return false;
  } else {
    return false;
  }

  // Currently we only support static shapes and element types that are ints,
  // floats, or vectors of ints or floats.
  if (!type.hasStaticShape())
    return false;

  Type elementType = type.getElementType();
  if (auto vecType = dyn_cast<VectorType>(elementType))
    elementType = vecType.getElementType();
  return elementType.isIntOrFloat();
}
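
// Examples (illustrative): an op like
//   memref.alloc() : memref<4x4xf32, #spirv.storage_class<Workgroup>>
// is supported, while a dynamically shaped allocation such as
//   memref.alloc(%n) : memref<?xf32, #spirv.storage_class<Workgroup>>
// is rejected because the shape is not static.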

/// Returns the scope to use for atomic operations emitted when emulating
/// store operations of unsupported integer bitwidths, based on the memref
/// type. Returns std::nullopt on failure.
static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
  if (!sc)
    return std::nullopt;
  switch (sc.getValue()) {
  case spirv::StorageClass::StorageBuffer:
    return spirv::Scope::Device;
  case spirv::StorageClass::Workgroup:
    return spirv::Scope::Workgroup;
  default:
    break;
  }
  return std::nullopt;
}

/// Casts the given `srcInt` into a boolean value.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
  if (srcInt.getType().isInteger(1))
    return srcInt;

  auto zero = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder);
  return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, zero);
}

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires a ConversionPattern, while DRR
// only generates normal RewritePatterns.

namespace {

/// Converts memref.alloca to SPIR-V Function variables.
class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts an allocation operation to SPIR-V. Currently only supports lowering
/// to Workgroup memory when the size is constant. Note that this pattern needs
/// to be applied in a pass that runs at least at spirv.module scope since it
/// will add global variables into the spirv.module.
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
class AtomicRMWOpPattern final
    : public OpConversionPattern<memref::AtomicRMWOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Removes a deallocation if it matches a supported allocation. Currently only
/// removes the deallocation if the memory space is workgroup memory.
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store on integers.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
256 
257 /// Converts memref.memory_space_cast to the appropriate spirv cast operations.
258 class MemorySpaceCastOpPattern final
259  : public OpConversionPattern<memref::MemorySpaceCastOp> {
260 public:
262 
263  LogicalResult
264  matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
265  ConversionPatternRewriter &rewriter) const override;
266 };

/// Converts memref.store to spirv.Store.
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ReinterpretCastPattern final
    : public OpConversionPattern<memref::ReinterpretCastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class CastPattern final : public OpConversionPattern<memref::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value src = adaptor.getSource();
    Type srcType = src.getType();

    const TypeConverter *converter = getTypeConverter();
    Type dstType = converter->convertType(op.getType());
    if (srcType != dstType)
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "types don't match: " << srcType << " and " << dstType;
      });

    rewriter.replaceOp(op, src);
    return success();
  }
};

/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
class ExtractAlignedPointerAsIndexOpPattern final
    : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = allocaOp.getType();
  if (!isAllocationSupported(allocaOp, allocType))
    return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");

  rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
                                                 spirv::StorageClass::Function,
                                                 /*initializer=*/nullptr);
  return success();
}
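
// A sketch of the lowering (illustrative, types abbreviated):
//   %0 = memref.alloca() : memref<4xf32, #spirv.storage_class<Function>>
// becomes
//   %0 = spirv.Variable : !spirv.ptr<!spirv.array<4 x f32>, Function>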

//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = operation.getType();
  if (!isAllocationSupported(operation, allocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(operation, "type conversion failed");

  // Insert spirv.GlobalVariable for this allocation.
  Operation *parent =
      SymbolTable::getNearestSymbolTable(operation->getParentOp());
  if (!parent)
    return failure();
  Location loc = operation.getLoc();
  spirv::GlobalVariableOp varOp;
  {
    OpBuilder::InsertionGuard guard(rewriter);
    Block &entryBlock = *parent->getRegion(0).begin();
    rewriter.setInsertionPointToStart(&entryBlock);
    auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
    std::string varName =
        std::string("__workgroup_mem__") +
        std::to_string(std::distance(varOps.begin(), varOps.end()));
    varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
                                                     /*initializer=*/nullptr);
  }

  // Get pointer to global variable at the current scope.
  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
  return success();
}
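
// A sketch of the lowering (illustrative, types abbreviated): for
//   %0 = memref.alloc() : memref<4xf32, #spirv.storage_class<Workgroup>>
// a module-level global is created,
//   spirv.GlobalVariable @__workgroup_mem__0 :
//       !spirv.ptr<!spirv.array<4 x f32>, Workgroup>
// and the alloc is replaced with
//   %0 = spirv.mlir.addressof @__workgroup_mem__0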

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

LogicalResult
AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (isa<FloatType>(atomicOp.getType()))
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unimplemented floating-point case");

  auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unsupported memref memory space");

  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Type resultType = typeConverter.convertType(atomicOp.getType());
  if (!resultType)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "failed to convert result type");

  auto loc = atomicOp.getLoc();
  Value ptr =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!ptr)
    return failure();

#define ATOMIC_CASE(kind, spirvOp)                                             \
  case arith::AtomicRMWKind::kind:                                             \
    rewriter.replaceOpWithNewOp<spirv::spirvOp>(                               \
        atomicOp, resultType, ptr, *scope,                                     \
        spirv::MemorySemantics::AcquireRelease, adaptor.getValue());           \
    break

  switch (atomicOp.getKind()) {
    ATOMIC_CASE(addi, AtomicIAddOp);
    ATOMIC_CASE(maxs, AtomicSMaxOp);
    ATOMIC_CASE(maxu, AtomicUMaxOp);
    ATOMIC_CASE(mins, AtomicSMinOp);
    ATOMIC_CASE(minu, AtomicUMinOp);
    ATOMIC_CASE(ori, AtomicOrOp);
    ATOMIC_CASE(andi, AtomicAndOp);
  default:
    return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
  }

#undef ATOMIC_CASE

  return success();
}
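
// A sketch of the lowering (illustrative, types abbreviated):
//   %old = memref.atomic_rmw addi %val, %m[%i] : (i32, memref<8xi32, ...>) -> i32
// becomes roughly
//   %ptr = spirv.AccessChain %base[%c0, %i]
//   %old = spirv.AtomicIAdd <Device> <AcquireRelease> %ptr, %val
// where the scope (<Device> here) is derived from the memref's storage class.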

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult
DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
  if (!isAllocationSupported(operation, deallocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
  rewriter.eraseOp(operation);
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

struct MemoryRequirements {
  spirv::MemoryAccessAttr memoryAccess;
  IntegerAttr alignment;
};

/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
/// any.
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
  MLIRContext *ctx = accessedPtr.getContext();

  auto memoryAccess = spirv::MemoryAccess::None;
  if (isNontemporal) {
    memoryAccess = spirv::MemoryAccess::Nontemporal;
  }

  auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
  if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
    if (memoryAccess == spirv::MemoryAccess::None) {
      return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
    }
    return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess),
                              IntegerAttr{}};
  }

  // PhysicalStorageBuffers require the `Aligned` attribute.
  auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
  if (!pointeeType)
    return failure();

  // For scalar types, the alignment is determined by their size.
  std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
  if (!sizeInBytes.has_value())
    return failure();

  memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
  auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
  return MemoryRequirements{memAccessAttr, alignment};
}

/// Given an accessed SPIR-V pointer and the original memref load/store
/// `memAccess` op, calculates the alignment requirements, if any. Takes into
/// account the alignment attributes applied to the load/store op.
template <class LoadOrStoreOp>
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
  static_assert(
      llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
      "Must be called on either memref::LoadOp or memref::StoreOp");

  Operation *memrefAccessOp = loadOrStoreOp.getOperation();
  auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
      spirv::attributeName<spirv::MemoryAccess>());
  auto memrefAlignment =
      memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
  if (memrefMemAccess && memrefAlignment)
    return MemoryRequirements{memrefMemAccess, memrefAlignment};

  return calculateMemoryRequirements(accessedPtr,
                                     loadOrStoreOp.getNontemporal());
}

LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  auto loc = loadOp.getLoc();
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return failure();

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  Type dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = arrayType.getElementType();
    else
      dstType = pointeeType;
  } else {
    // For Vulkan we need to extract element from wrapping struct and array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = arrayType.getElementType();
    else
      dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
  }
  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // If the rewritten load op has the same bit width, use the loaded value
  // directly.
  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          loadOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
                                                   memoryAccess, alignment);
    if (isBool)
      loadVal = castIntNToBool(loc, loadVal, rewriter);
    rewriter.replaceOp(loadOp, loadVal);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Assume that getElementPtr() linearizes the access: even for a scalar, the
  // method returns a linearized access. If the access were not linearized,
  // there would be offset issues.
  assert(accessChainOp.getIndices().size() == 2);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
                                                   memoryAccess, alignment);

  // Shift the bits to the rightmost.
  // ____XXXX________ -> ____________XXXX
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, spvLoadOp.getType(), spvLoadOp, offset);

  // Apply the mask to extract the corresponding bits.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  result =
      rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);

  // Apply sign extension to the loaded value unconditionally. The signedness
  // semantics are carried by the operations themselves; we rely on other
  // patterns to handle any casting.
  IntegerAttr shiftValueAttr =
      rewriter.getIntegerAttr(dstType, dstBits - srcBits);
  Value shiftValue =
      rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
  result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
                                                            result, shiftValue);
  result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, dstType, result, shiftValue);

  rewriter.replaceOp(loadOp, result);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}
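
// A sketch of the narrow-load emulation above (illustrative), loading an i8
// element from memory emulated with i32 words:
//   %w   = spirv.Load %adjustedPtr : i32            // word holding the i8
//   %s   = spirv.ShiftRightArithmetic %w, %offset   // byte to the LSB
//   %b   = spirv.BitwiseAnd %s, %mask               // mask = 0xFF
//   %tmp = spirv.ShiftLeftLogical %b, %c24          // sign-extend i8 in i32:
//   %res = spirv.ShiftRightArithmetic %tmp, %c24    // shl 24 then ashr 24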

LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return failure();
  Value loadPtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), loadOp.getLoc(), rewriter);

  if (!loadPtr)
    return failure();

  auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
                                             alignment);
  return success();
}
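
// A sketch (illustrative, types abbreviated):
//   %v = memref.load %m[%i] : memref<8xf32, #spirv.storage_class<StorageBuffer>>
// maps to
//   %ptr = spirv.AccessChain %base[%c0, %i]
//   %v   = spirv.Load "StorageBuffer" %ptr : f32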

LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp,
                                       "element type is not a signless int");

  auto loc = storeOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to convert element pointer type");

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();

  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(storeOp,
                                       "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  IntegerType dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(pointeeType);
  } else {
    // For Vulkan we need to extract element from wrapping struct and array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(
          cast<spirv::RuntimeArrayType>(structElemType).getElementType());
  }

  if (!dstType)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine destination element type");

  int dstBits = static_cast<int>(dstType.getWidth());
  assert(dstBits % srcBits == 0);

  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          storeOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value storeVal = adaptor.getValue();
    if (isBool)
      storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
                                                memoryAccess, alignment);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Since multiple threads may access the same memory concurrently, the
  // emulation is done with atomic operations. E.g., if the stored value is i8,
  // rewrite the StoreOp to:
  // 1) load a 32-bit integer
  // 2) clear 8 bits in the loaded value
  // 3) set 8 bits in the loaded value
  // 4) store the 32-bit value back
  //
  // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
  // loaded 32-bit value and the shifted 8-bit store value) as another atomic
  // step.
  assert(accessChainOp.getIndices().size() == 2);
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);

  // Create a mask to clear the destination. E.g., if it is the second i8 in
  // an i32, 0xFFFF00FF is created.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
      loc, dstType, mask, offset);
  clearBitsMask =
      rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);

  Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");

  rewriter.create<spirv::AtomicAndOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      clearBitsMask);
  rewriter.create<spirv::AtomicOrOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      storeVal);

  // The results of the atomic ops are unused. Since the ops are already
  // inserted, we can simply erase the original StoreOp. Note that
  // rewriter.replaceOp() doesn't work here because it requires the numbers of
  // results to match.
  rewriter.eraseOp(storeOp);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}
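
// A sketch of the atomic narrow-store emulation above (illustrative), storing
// into the second byte of an i32 word, so offset = 8 and the clear mask is
// ~(0xFF << 8) = 0xFFFF00FF:
//   %0 = spirv.AtomicAnd <Device> <AcquireRelease> %ptr, %c0xFFFF00FF  // clear
//   %1 = spirv.AtomicOr  <Device> <AcquireRelease> %ptr, %shiftedVal   // set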

//===----------------------------------------------------------------------===//
// MemorySpaceCastOp
//===----------------------------------------------------------------------===//

LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
    memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = addrCastOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  if (!typeConverter.allows(spirv::Capability::Kernel))
    return rewriter.notifyMatchFailure(
        loc, "address space casts require kernel capability");

  auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
  if (!sourceType)
    return rewriter.notifyMatchFailure(
        loc, "SPIR-V lowering requires ranked memref types");
  auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());

  auto sourceStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
  if (!sourceStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
      diag << "source address space " << sourceType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });
  auto resultStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
  if (!resultStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
      diag << "result address space " << resultType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });

  spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
  spirv::StorageClass resultSc = resultStorageClassAttr.getValue();

  Value result = adaptor.getSource();
  Type resultPtrType = typeConverter.convertType(resultType);
  if (!resultPtrType)
    return rewriter.notifyMatchFailure(addrCastOp,
                                       "failed to convert memref type");

  Type genericPtrType = resultPtrType;
  // SPIR-V doesn't have a general address space cast operation. Instead, it
  // has conversions to and from generic pointers. To implement the general
  // case, we use a specific-to-generic conversion when the source class is not
  // generic. Then, when the result storage class is not generic, we convert
  // the generic pointer (either the input or an intermediate result) to that
  // class. This also means that we'll need the intermediate generic pointer
  // type if neither the source nor the destination storage class is generic.
  if (sourceSc != spirv::StorageClass::Generic &&
      resultSc != spirv::StorageClass::Generic) {
    Type intermediateType =
        MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
                        sourceType.getLayout(),
                        rewriter.getAttr<spirv::StorageClassAttr>(
                            spirv::StorageClass::Generic));
    genericPtrType = typeConverter.convertType(intermediateType);
  }
  if (sourceSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
  }
  if (resultSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
  }
  rewriter.replaceOp(addrCastOp, result);
  return success();
}
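
// A sketch of the two-step cast (illustrative), e.g. Workgroup to
// CrossWorkgroup:
//   %g = spirv.PtrCastToGeneric %src
//        : !spirv.ptr<f32, Workgroup> to !spirv.ptr<f32, Generic>
//   %r = spirv.GenericCastToPtr %g
//        : !spirv.ptr<f32, Generic> to !spirv.ptr<f32, CrossWorkgroup>
// When one side is already Generic, the corresponding step is skipped.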

LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(
        storeOp, "signless ints are handled by IntStoreOpPattern");
  auto storePtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), storeOp.getLoc(), rewriter);

  if (!storePtr)
    return rewriter.notifyMatchFailure(storeOp, "type conversion failed");

  auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::StoreOp>(
      storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
  return success();
}

LogicalResult ReinterpretCastPattern::matchAndRewrite(
    memref::ReinterpretCastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value src = adaptor.getSource();
  auto srcType = dyn_cast<spirv::PointerType>(src.getType());

  if (!srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid src type " << src.getType();
    });

  const TypeConverter *converter = getTypeConverter();

  auto dstType = converter->convertType<spirv::PointerType>(op.getType());
  if (dstType != srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid dst type " << op.getType();
    });

  OpFoldResult offset =
      getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
          .front();
  if (isZeroInteger(offset)) {
    rewriter.replaceOp(op, src);
    return success();
  }

  Type intType = converter->convertType(rewriter.getIndexType());
  if (!intType)
    return rewriter.notifyMatchFailure(op, "failed to convert index type");

  Location loc = op.getLoc();
  auto offsetValue = [&]() -> Value {
    if (auto val = dyn_cast<Value>(offset))
      return val;

    int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
    Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
    return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
  }();

  rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
      op, src, offsetValue, std::nullopt);
  return success();
}
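
// A sketch (illustrative): a cast with a nonzero offset such as
//   %r = memref.reinterpret_cast %m to offset: [4], ... : memref<...> to memref<...>
// maps to
//   %r = spirv.InBoundsPtrAccessChain %src[%c4]
// while a zero offset simply forwards the source pointer.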

//===----------------------------------------------------------------------===//
// ExtractAlignedPointerAsIndexOp
//===----------------------------------------------------------------------===//

LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
    memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Type indexType = typeConverter.getIndexType();
  rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
                                                      adaptor.getSource());
  return success();
}
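
// A sketch (illustrative, types abbreviated):
//   %idx = memref.extract_aligned_pointer_as_index %m : memref<8xf32, ...> -> index
// maps to
//   %idx = spirv.ConvertPtrToU %base : !spirv.ptr<...> to i32
// with the integer width chosen by the type converter's index type.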

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

namespace mlir {
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                   RewritePatternSet &patterns) {
  patterns
      .add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
           DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
           MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
           CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
          typeConverter, patterns.getContext());
}
} // namespace mlir
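
// A minimal usage sketch (illustrative, not part of this file): a conversion
// pass would typically pair these patterns with a SPIRVTypeConverter and a
// SPIRVConversionTarget, along the lines of:
//
//   SPIRVConversionOptions options;
//   SPIRVTypeConverter typeConverter(targetEnvAttr, options);
//   RewritePatternSet patterns(context);
//   populateMemRefToSPIRVPatterns(typeConverter, patterns);
//   auto target = SPIRVConversionTarget::get(targetEnvAttr);
//   if (failed(applyPartialConversion(module, *target, std::move(patterns))))
//     signalPassFailure();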