MLIR  19.0.0git
MemRefToSPIRV.cpp
Go to the documentation of this file.
1 //===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert MemRef dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/Visitors.h"
25 #include "llvm/Support/Debug.h"
26 #include <cassert>
27 #include <optional>
28 
29 #define DEBUG_TYPE "memref-to-spirv-pattern"
30 
31 using namespace mlir;
32 
33 //===----------------------------------------------------------------------===//
34 // Utility functions
35 //===----------------------------------------------------------------------===//
36 
37 /// Returns the offset of the value in `targetBits` representation.
38 ///
39 /// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
40 /// It's assumed to be non-negative.
41 ///
42 /// When accessing an element in the array treating as having elements of
43 /// `targetBits`, multiple values are loaded in the same time. The method
44 /// returns the offset where the `srcIdx` locates in the value. For example, if
45 /// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is
46 /// located at (x % 4) * 8. Because there are four elements in one i32, and one
47 /// element has 8 bits.
48 static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
49  int targetBits, OpBuilder &builder) {
50  assert(targetBits % sourceBits == 0);
51  Type type = srcIdx.getType();
52  IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
53  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
54  IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
55  auto srcBitsValue =
56  builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
57  auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
58  return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
59 }
60 
61 /// Returns an adjusted spirv::AccessChainOp. Based on the
62 /// extension/capabilities, certain integer bitwidths `sourceBits` might not be
63 /// supported. During conversion if a memref of an unsupported type is used,
64 /// load/stores to this memref need to be modified to use a supported higher
/// bitwidth `targetBits` and extracting the required bits. When accessing a
/// 1D array (spirv.array or spirv.rtarray), the last index is modified to load
/// the bits needed. The extraction of the actual bits needed is handled
/// separately. Note that this only works for a 1-D tensor.
69 static Value
71  spirv::AccessChainOp op, int sourceBits,
72  int targetBits, OpBuilder &builder) {
73  assert(targetBits % sourceBits == 0);
74  const auto loc = op.getLoc();
75  Value lastDim = op->getOperand(op.getNumOperands() - 1);
76  Type type = lastDim.getType();
77  IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
78  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
79  auto indices = llvm::to_vector<4>(op.getIndices());
80  // There are two elements if this is a 1-D tensor.
81  assert(indices.size() == 2);
82  indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
83  Type t = typeConverter.convertType(op.getComponentPtr().getType());
84  return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
85 }
86 
87 /// Casts the given `srcBool` into an integer of `dstType`.
88 static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
89  OpBuilder &builder) {
90  assert(srcBool.getType().isInteger(1));
91  if (dstType.isInteger(1))
92  return srcBool;
93  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
94  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
95  return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
96  zero);
97 }
98 
99 /// Returns the `targetBits`-bit value shifted by the given `offset`, and cast
100 /// to the type destination type, and masked.
101 static Value shiftValue(Location loc, Value value, Value offset, Value mask,
102  OpBuilder &builder) {
103  IntegerType dstType = cast<IntegerType>(mask.getType());
104  int targetBits = static_cast<int>(dstType.getWidth());
105  int valueBits = value.getType().getIntOrFloatBitWidth();
106  assert(valueBits <= targetBits);
107 
108  if (valueBits == 1) {
109  value = castBoolToIntN(loc, value, dstType, builder);
110  } else {
111  if (valueBits < targetBits) {
112  value = builder.create<spirv::UConvertOp>(
113  loc, builder.getIntegerType(targetBits), value);
114  }
115 
116  value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
117  }
118  return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
119  value, offset);
120 }
121 
122 /// Returns true if the allocations of memref `type` generated from `allocOp`
123 /// can be lowered to SPIR-V.
124 static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
125  if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
126  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
127  if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
128  return false;
129  } else if (isa<memref::AllocaOp>(allocOp)) {
130  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
131  if (!sc || sc.getValue() != spirv::StorageClass::Function)
132  return false;
133  } else {
134  return false;
135  }
136 
137  // Currently only support static shape and int or float or vector of int or
138  // float element type.
139  if (!type.hasStaticShape())
140  return false;
141 
142  Type elementType = type.getElementType();
143  if (auto vecType = dyn_cast<VectorType>(elementType))
144  elementType = vecType.getElementType();
145  return elementType.isIntOrFloat();
146 }
147 
148 /// Returns the scope to use for atomic operations use for emulating store
149 /// operations of unsupported integer bitwidths, based on the memref
150 /// type. Returns std::nullopt on failure.
151 static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
152  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
153  switch (sc.getValue()) {
154  case spirv::StorageClass::StorageBuffer:
155  return spirv::Scope::Device;
156  case spirv::StorageClass::Workgroup:
157  return spirv::Scope::Workgroup;
158  default:
159  break;
160  }
161  return {};
162 }
163 
164 /// Casts the given `srcInt` into a boolean value.
165 static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
166  if (srcInt.getType().isInteger(1))
167  return srcInt;
168 
169  auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
170  return builder.createOrFold<spirv::IEqualOp>(loc, srcInt, one);
171 }
172 
173 //===----------------------------------------------------------------------===//
174 // Operation conversion
175 //===----------------------------------------------------------------------===//
176 
177 // Note that DRR cannot be used for the patterns in this file: we may need to
178 // convert type along the way, which requires ConversionPattern. DRR generates
179 // normal RewritePattern.
180 
181 namespace {
182 
183 /// Converts memref.alloca to SPIR-V Function variables.
184 class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
185 public:
187 
189  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
190  ConversionPatternRewriter &rewriter) const override;
191 };
192 
193 /// Converts an allocation operation to SPIR-V. Currently only supports lowering
194 /// to Workgroup memory when the size is constant. Note that this pattern needs
195 /// to be applied in a pass that runs at least at spirv.module scope since it
/// will add global variables into the spirv.module.
197 class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
198 public:
200 
202  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
203  ConversionPatternRewriter &rewriter) const override;
204 };
205 
/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
207 class AtomicRMWOpPattern final
208  : public OpConversionPattern<memref::AtomicRMWOp> {
209 public:
211 
213  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
214  ConversionPatternRewriter &rewriter) const override;
215 };
216 
217 /// Removed a deallocation if it is a supported allocation. Currently only
218 /// removes deallocation if the memory space is workgroup memory.
219 class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
220 public:
222 
224  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
225  ConversionPatternRewriter &rewriter) const override;
226 };
227 
228 /// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
229 class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
230 public:
232 
234  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
235  ConversionPatternRewriter &rewriter) const override;
236 };
237 
238 /// Converts memref.load to spirv.Load + spirv.AccessChain.
239 class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
240 public:
242 
244  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
245  ConversionPatternRewriter &rewriter) const override;
246 };
247 
248 /// Converts memref.store to spirv.Store on integers.
249 class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
250 public:
252 
254  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
255  ConversionPatternRewriter &rewriter) const override;
256 };
257 
258 /// Converts memref.memory_space_cast to the appropriate spirv cast operations.
259 class MemorySpaceCastOpPattern final
260  : public OpConversionPattern<memref::MemorySpaceCastOp> {
261 public:
263 
265  matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
266  ConversionPatternRewriter &rewriter) const override;
267 };
268 
269 /// Converts memref.store to spirv.Store.
270 class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
271 public:
273 
275  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
276  ConversionPatternRewriter &rewriter) const override;
277 };
278 
279 class ReinterpretCastPattern final
280  : public OpConversionPattern<memref::ReinterpretCastOp> {
281 public:
283 
285  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
286  ConversionPatternRewriter &rewriter) const override;
287 };
288 
289 class CastPattern final : public OpConversionPattern<memref::CastOp> {
290 public:
292 
294  matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
295  ConversionPatternRewriter &rewriter) const override {
296  Value src = adaptor.getSource();
297  Type srcType = src.getType();
298 
299  const TypeConverter *converter = getTypeConverter();
300  Type dstType = converter->convertType(op.getType());
301  if (srcType != dstType)
302  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
303  diag << "types doesn't match: " << srcType << " and " << dstType;
304  });
305 
306  rewriter.replaceOp(op, src);
307  return success();
308  }
309 };
310 
311 } // namespace
312 
313 //===----------------------------------------------------------------------===//
314 // AllocaOp
315 //===----------------------------------------------------------------------===//
316 
318 AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
319  ConversionPatternRewriter &rewriter) const {
320  MemRefType allocType = allocaOp.getType();
321  if (!isAllocationSupported(allocaOp, allocType))
322  return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");
323 
324  // Get the SPIR-V type for the allocation.
325  Type spirvType = getTypeConverter()->convertType(allocType);
326  if (!spirvType)
327  return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");
328 
329  rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
330  spirv::StorageClass::Function,
331  /*initializer=*/nullptr);
332  return success();
333 }
334 
335 //===----------------------------------------------------------------------===//
336 // AllocOp
337 //===----------------------------------------------------------------------===//
338 
340 AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
341  ConversionPatternRewriter &rewriter) const {
342  MemRefType allocType = operation.getType();
343  if (!isAllocationSupported(operation, allocType))
344  return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
345 
346  // Get the SPIR-V type for the allocation.
347  Type spirvType = getTypeConverter()->convertType(allocType);
348  if (!spirvType)
349  return rewriter.notifyMatchFailure(operation, "type conversion failed");
350 
351  // Insert spirv.GlobalVariable for this allocation.
352  Operation *parent =
353  SymbolTable::getNearestSymbolTable(operation->getParentOp());
354  if (!parent)
355  return failure();
356  Location loc = operation.getLoc();
357  spirv::GlobalVariableOp varOp;
358  {
359  OpBuilder::InsertionGuard guard(rewriter);
360  Block &entryBlock = *parent->getRegion(0).begin();
361  rewriter.setInsertionPointToStart(&entryBlock);
362  auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
363  std::string varName =
364  std::string("__workgroup_mem__") +
365  std::to_string(std::distance(varOps.begin(), varOps.end()));
366  varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
367  /*initializer=*/nullptr);
368  }
369 
370  // Get pointer to global variable at the current scope.
371  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
372  return success();
373 }
374 
375 //===----------------------------------------------------------------------===//
// AtomicRMWOp
377 //===----------------------------------------------------------------------===//
378 
380 AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
381  OpAdaptor adaptor,
382  ConversionPatternRewriter &rewriter) const {
383  if (isa<FloatType>(atomicOp.getType()))
384  return rewriter.notifyMatchFailure(atomicOp,
385  "unimplemented floating-point case");
386 
387  auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
388  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
389  if (!scope)
390  return rewriter.notifyMatchFailure(atomicOp,
391  "unsupported memref memory space");
392 
393  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
394  Type resultType = typeConverter.convertType(atomicOp.getType());
395  if (!resultType)
396  return rewriter.notifyMatchFailure(atomicOp,
397  "failed to convert result type");
398 
399  auto loc = atomicOp.getLoc();
400  Value ptr =
401  spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
402  adaptor.getIndices(), loc, rewriter);
403 
404  if (!ptr)
405  return failure();
406 
407 #define ATOMIC_CASE(kind, spirvOp) \
408  case arith::AtomicRMWKind::kind: \
409  rewriter.replaceOpWithNewOp<spirv::spirvOp>( \
410  atomicOp, resultType, ptr, *scope, \
411  spirv::MemorySemantics::AcquireRelease, adaptor.getValue()); \
412  break
413 
414  switch (atomicOp.getKind()) {
415  ATOMIC_CASE(addi, AtomicIAddOp);
416  ATOMIC_CASE(maxs, AtomicSMaxOp);
417  ATOMIC_CASE(maxu, AtomicUMaxOp);
418  ATOMIC_CASE(mins, AtomicSMinOp);
419  ATOMIC_CASE(minu, AtomicUMinOp);
420  ATOMIC_CASE(ori, AtomicOrOp);
421  ATOMIC_CASE(andi, AtomicAndOp);
422  default:
423  return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
424  }
425 
426 #undef ATOMIC_CASE
427 
428  return success();
429 }
430 
431 //===----------------------------------------------------------------------===//
432 // DeallocOp
433 //===----------------------------------------------------------------------===//
434 
436 DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
437  OpAdaptor adaptor,
438  ConversionPatternRewriter &rewriter) const {
439  MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
440  if (!isAllocationSupported(operation, deallocType))
441  return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
442  rewriter.eraseOp(operation);
443  return success();
444 }
445 
446 //===----------------------------------------------------------------------===//
447 // LoadOp
448 //===----------------------------------------------------------------------===//
449 
451  spirv::MemoryAccessAttr memoryAccess;
452  IntegerAttr alignment;
453 };
454 
455 /// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
456 /// any.
458 calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
459  MLIRContext *ctx = accessedPtr.getContext();
460 
461  auto memoryAccess = spirv::MemoryAccess::None;
462  if (isNontemporal) {
463  memoryAccess = spirv::MemoryAccess::Nontemporal;
464  }
465 
466  auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
467  if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
468  if (memoryAccess == spirv::MemoryAccess::None) {
469  return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
470  }
471  return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess),
472  IntegerAttr{}};
473  }
474 
475  // PhysicalStorageBuffers require the `Aligned` attribute.
476  auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
477  if (!pointeeType)
478  return failure();
479 
480  // For scalar types, the alignment is determined by their size.
481  std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
482  if (!sizeInBytes.has_value())
483  return failure();
484 
485  memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
486  auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
487  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
488  return MemoryRequirements{memAccessAttr, alignment};
489 }
490 
491 /// Given an accessed SPIR-V pointer and the original memref load/store
492 /// `memAccess` op, calculates the alignment requirements, if any. Takes into
493 /// account the alignment attributes applied to the load/store op.
494 template <class LoadOrStoreOp>
496 calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
497  static_assert(
498  llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
499  "Must be called on either memref::LoadOp or memref::StoreOp");
500 
501  Operation *memrefAccessOp = loadOrStoreOp.getOperation();
502  auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
503  spirv::attributeName<spirv::MemoryAccess>());
504  auto memrefAlignment =
505  memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
506  if (memrefMemAccess && memrefAlignment)
507  return MemoryRequirements{memrefMemAccess, memrefAlignment};
508 
509  return calculateMemoryRequirements(accessedPtr,
510  loadOrStoreOp.getNontemporal());
511 }
512 
514 IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
515  ConversionPatternRewriter &rewriter) const {
516  auto loc = loadOp.getLoc();
517  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
518  if (!memrefType.getElementType().isSignlessInteger())
519  return failure();
520 
521  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
522  Value accessChain =
523  spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
524  adaptor.getIndices(), loc, rewriter);
525 
526  if (!accessChain)
527  return failure();
528 
529  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
530  bool isBool = srcBits == 1;
531  if (isBool)
532  srcBits = typeConverter.getOptions().boolNumBits;
533 
534  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
535  if (!pointerType)
536  return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");
537 
538  Type pointeeType = pointerType.getPointeeType();
539  Type dstType;
540  if (typeConverter.allows(spirv::Capability::Kernel)) {
541  if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
542  dstType = arrayType.getElementType();
543  else
544  dstType = pointeeType;
545  } else {
546  // For Vulkan we need to extract element from wrapping struct and array.
547  Type structElemType =
548  cast<spirv::StructType>(pointeeType).getElementType(0);
549  if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
550  dstType = arrayType.getElementType();
551  else
552  dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
553  }
554  int dstBits = dstType.getIntOrFloatBitWidth();
555  assert(dstBits % srcBits == 0);
556 
557  // If the rewritten load op has the same bit width, use the loading value
558  // directly.
559  if (srcBits == dstBits) {
560  auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp);
561  if (failed(memoryRequirements))
562  return rewriter.notifyMatchFailure(
563  loadOp, "failed to determine memory requirements");
564 
565  auto [memoryAccess, alignment] = *memoryRequirements;
566  Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
567  memoryAccess, alignment);
568  if (isBool)
569  loadVal = castIntNToBool(loc, loadVal, rewriter);
570  rewriter.replaceOp(loadOp, loadVal);
571  return success();
572  }
573 
574  // Bitcasting is currently unsupported for Kernel capability /
575  // spirv.PtrAccessChain.
576  if (typeConverter.allows(spirv::Capability::Kernel))
577  return failure();
578 
579  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
580  if (!accessChainOp)
581  return failure();
582 
583  // Assume that getElementPtr() works linearizely. If it's a scalar, the method
584  // still returns a linearized accessing. If the accessing is not linearized,
585  // there will be offset issues.
586  assert(accessChainOp.getIndices().size() == 2);
587  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
588  srcBits, dstBits, rewriter);
589  auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp);
590  if (failed(memoryRequirements))
591  return rewriter.notifyMatchFailure(
592  loadOp, "failed to determine memory requirements");
593 
594  auto [memoryAccess, alignment] = *memoryRequirements;
595  Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
596  memoryAccess, alignment);
597 
598  // Shift the bits to the rightmost.
599  // ____XXXX________ -> ____________XXXX
600  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
601  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
602  Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
603  loc, spvLoadOp.getType(), spvLoadOp, offset);
604 
605  // Apply the mask to extract corresponding bits.
606  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
607  loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
608  result =
609  rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
610 
611  // Apply sign extension on the loading value unconditionally. The signedness
612  // semantic is carried in the operator itself, we relies other pattern to
613  // handle the casting.
614  IntegerAttr shiftValueAttr =
615  rewriter.getIntegerAttr(dstType, dstBits - srcBits);
616  Value shiftValue =
617  rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
618  result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
619  result, shiftValue);
620  result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
621  loc, dstType, result, shiftValue);
622 
623  rewriter.replaceOp(loadOp, result);
624 
625  assert(accessChainOp.use_empty());
626  rewriter.eraseOp(accessChainOp);
627 
628  return success();
629 }
630 
632 LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
633  ConversionPatternRewriter &rewriter) const {
634  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
635  if (memrefType.getElementType().isSignlessInteger())
636  return failure();
637  Value loadPtr = spirv::getElementPtr(
638  *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
639  adaptor.getIndices(), loadOp.getLoc(), rewriter);
640 
641  if (!loadPtr)
642  return failure();
643 
644  auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
645  if (failed(memoryRequirements))
646  return rewriter.notifyMatchFailure(
647  loadOp, "failed to determine memory requirements");
648 
649  auto [memoryAccess, alignment] = *memoryRequirements;
650  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
651  alignment);
652  return success();
653 }
654 
656 IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
657  ConversionPatternRewriter &rewriter) const {
658  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
659  if (!memrefType.getElementType().isSignlessInteger())
660  return rewriter.notifyMatchFailure(storeOp,
661  "element type is not a signless int");
662 
663  auto loc = storeOp.getLoc();
664  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
665  Value accessChain =
666  spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
667  adaptor.getIndices(), loc, rewriter);
668 
669  if (!accessChain)
670  return rewriter.notifyMatchFailure(
671  storeOp, "failed to convert element pointer type");
672 
673  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
674 
675  bool isBool = srcBits == 1;
676  if (isBool)
677  srcBits = typeConverter.getOptions().boolNumBits;
678 
679  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
680  if (!pointerType)
681  return rewriter.notifyMatchFailure(storeOp,
682  "failed to convert memref type");
683 
684  Type pointeeType = pointerType.getPointeeType();
685  IntegerType dstType;
686  if (typeConverter.allows(spirv::Capability::Kernel)) {
687  if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
688  dstType = dyn_cast<IntegerType>(arrayType.getElementType());
689  else
690  dstType = dyn_cast<IntegerType>(pointeeType);
691  } else {
692  // For Vulkan we need to extract element from wrapping struct and array.
693  Type structElemType =
694  cast<spirv::StructType>(pointeeType).getElementType(0);
695  if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
696  dstType = dyn_cast<IntegerType>(arrayType.getElementType());
697  else
698  dstType = dyn_cast<IntegerType>(
699  cast<spirv::RuntimeArrayType>(structElemType).getElementType());
700  }
701 
702  if (!dstType)
703  return rewriter.notifyMatchFailure(
704  storeOp, "failed to determine destination element type");
705 
706  int dstBits = static_cast<int>(dstType.getWidth());
707  assert(dstBits % srcBits == 0);
708 
709  if (srcBits == dstBits) {
710  auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
711  if (failed(memoryRequirements))
712  return rewriter.notifyMatchFailure(
713  storeOp, "failed to determine memory requirements");
714 
715  auto [memoryAccess, alignment] = *memoryRequirements;
716  Value storeVal = adaptor.getValue();
717  if (isBool)
718  storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
719  rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
720  memoryAccess, alignment);
721  return success();
722  }
723 
724  // Bitcasting is currently unsupported for Kernel capability /
725  // spirv.PtrAccessChain.
726  if (typeConverter.allows(spirv::Capability::Kernel))
727  return failure();
728 
729  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
730  if (!accessChainOp)
731  return failure();
732 
733  // Since there are multiple threads in the processing, the emulation will be
734  // done with atomic operations. E.g., if the stored value is i8, rewrite the
735  // StoreOp to:
736  // 1) load a 32-bit integer
737  // 2) clear 8 bits in the loaded value
738  // 3) set 8 bits in the loaded value
739  // 4) store 32-bit value back
740  //
741  // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
742  // loaded 32-bit value and the shifted 8-bit store value) as another atomic
743  // step.
744  assert(accessChainOp.getIndices().size() == 2);
745  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
746  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
747 
748  // Create a mask to clear the destination. E.g., if it is the second i8 in
749  // i32, 0xFFFF00FF is created.
750  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
751  loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
752  Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
753  loc, dstType, mask, offset);
754  clearBitsMask =
755  rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
756 
757  Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
758  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
759  srcBits, dstBits, rewriter);
760  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
761  if (!scope)
762  return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");
763 
764  Value result = rewriter.create<spirv::AtomicAndOp>(
765  loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
766  clearBitsMask);
767  result = rewriter.create<spirv::AtomicOrOp>(
768  loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
769  storeVal);
770 
771  // The AtomicOrOp has no side effect. Since it is already inserted, we can
772  // just remove the original StoreOp. Note that rewriter.replaceOp()
773  // doesn't work because it only accepts that the numbers of result are the
774  // same.
775  rewriter.eraseOp(storeOp);
776 
777  assert(accessChainOp.use_empty());
778  rewriter.eraseOp(accessChainOp);
779 
780  return success();
781 }
782 
783 //===----------------------------------------------------------------------===//
784 // MemorySpaceCastOp
785 //===----------------------------------------------------------------------===//
786 
LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
    memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = addrCastOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  // Generic-pointer casts are only available in kernel (OpenCL-style) SPIR-V.
  if (!typeConverter.allows(spirv::Capability::Kernel))
    return rewriter.notifyMatchFailure(
        loc, "address space casts require kernel capability");

  auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
  if (!sourceType)
    return rewriter.notifyMatchFailure(
        loc, "SPIR-V lowering requires ranked memref types");
  auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());

  // Both sides must carry an explicit SPIR-V storage-class memory space.
  auto sourceStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
  if (!sourceStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
      diag << "source address space " << sourceType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });
  auto resultStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
  if (!resultStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
      diag << "result address space " << resultType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });

  spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
  spirv::StorageClass resultSc = resultStorageClassAttr.getValue();

  Value result = adaptor.getSource();
  Type resultPtrType = typeConverter.convertType(resultType);
  if (!resultPtrType)
    return rewriter.notifyMatchFailure(addrCastOp,
                                       "failed to convert memref type");

  Type genericPtrType = resultPtrType;
  // SPIR-V doesn't have a general address space cast operation. Instead, it has
  // conversions to and from generic pointers. To implement the general case,
  // we use specific-to-generic conversions when the source class is not
  // generic. Then when the result storage class is not generic, we convert the
  // generic pointer (either the input or an intermediate result) to that
  // class. This also means that we'll need the intermediate generic pointer
  // type if neither the source nor the destination have it.
  if (sourceSc != spirv::StorageClass::Generic &&
      resultSc != spirv::StorageClass::Generic) {
    Type intermediateType =
        MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
                        sourceType.getLayout(),
                        rewriter.getAttr<spirv::StorageClassAttr>(
                            spirv::StorageClass::Generic));
    genericPtrType = typeConverter.convertType(intermediateType);
  }
  if (sourceSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
  }
  if (resultSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
  }
  rewriter.replaceOp(addrCastOp, result);
  return success();
}
854 
856 StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
857  ConversionPatternRewriter &rewriter) const {
858  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
859  if (memrefType.getElementType().isSignlessInteger())
860  return rewriter.notifyMatchFailure(storeOp, "signless int");
861  auto storePtr = spirv::getElementPtr(
862  *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
863  adaptor.getIndices(), storeOp.getLoc(), rewriter);
864 
865  if (!storePtr)
866  return rewriter.notifyMatchFailure(storeOp, "type conversion failed");
867 
868  auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
869  if (failed(memoryRequirements))
870  return rewriter.notifyMatchFailure(
871  storeOp, "failed to determine memory requirements");
872 
873  auto [memoryAccess, alignment] = *memoryRequirements;
874  rewriter.replaceOpWithNewOp<spirv::StoreOp>(
875  storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
876  return success();
877 }
878 
879 LogicalResult ReinterpretCastPattern::matchAndRewrite(
880  memref::ReinterpretCastOp op, OpAdaptor adaptor,
881  ConversionPatternRewriter &rewriter) const {
882  Value src = adaptor.getSource();
883  auto srcType = dyn_cast<spirv::PointerType>(src.getType());
884 
885  if (!srcType)
886  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
887  diag << "invalid src type " << src.getType();
888  });
889 
890  const TypeConverter *converter = getTypeConverter();
891 
892  auto dstType = converter->convertType<spirv::PointerType>(op.getType());
893  if (dstType != srcType)
894  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
895  diag << "invalid dst type " << op.getType();
896  });
897 
898  OpFoldResult offset =
899  getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
900  .front();
901  if (isConstantIntValue(offset, 0)) {
902  rewriter.replaceOp(op, src);
903  return success();
904  }
905 
906  Type intType = converter->convertType(rewriter.getIndexType());
907  if (!intType)
908  return rewriter.notifyMatchFailure(op, "failed to convert index type");
909 
910  Location loc = op.getLoc();
911  auto offsetValue = [&]() -> Value {
912  if (auto val = dyn_cast<Value>(offset))
913  return val;
914 
915  int64_t attrVal = cast<IntegerAttr>(offset.get<Attribute>()).getInt();
916  Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
917  return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
918  }();
919 
920  rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
921  op, src, offsetValue, std::nullopt);
922  return success();
923 }
924 
925 //===----------------------------------------------------------------------===//
926 // Pattern population
927 //===----------------------------------------------------------------------===//
928 
929 namespace mlir {
931  RewritePatternSet &patterns) {
932  patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
933  DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
934  LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
935  ReinterpretCastPattern, CastPattern>(typeConverter,
936  patterns.getContext());
937 }
938 } // namespace mlir
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder)
Casts the given srcInt into a boolean value.
static Value shiftValue(Location loc, Value value, Value offset, Value mask, OpBuilder &builder)
Returns the targetBits-bit value shifted by the given offset, cast to the destination type.
static Value adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)
Returns an adjusted spirv::AccessChainOp.
static std::optional< spirv::Scope > getAtomicOpScope(MemRefType type)
Returns the scope to use for the atomic operations used when emulating store operations of unsupported integer bitwidths.
static bool isAllocationSupported(Operation *allocOp, MemRefType type)
Returns true if the allocations of memref type generated from allocOp can be lowered to SPIR-V.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
Returns the offset of the value in targetBits representation.
#define ATOMIC_CASE(kind, spirvOp)
static FailureOr< MemoryRequirements > calculateMemoryRequirements(Value accessedPtr, bool isNontemporal)
Given an accessed SPIR-V pointer, calculates its alignment requirements, if any.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder)
Casts the given srcBool into an integer of dstType.
@ None
static std::string diag(const llvm::Value &value)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:216
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition: Block.h:190
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
IndexType getIndexType()
Definition: Builders.cpp:71
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:100
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:522
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:545
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
iterator begin()
Definition: Region.h:55
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Type conversion from builtin types to SPIR-V types for shader interface.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Type conversion class.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:58
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:119
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:128
Type getType() const
Return the type of this value.
Definition: Value.h:125
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
spirv::MemoryAccessAttr memoryAccess
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26