MLIR  18.0.0git
MemRefToSPIRV.cpp
Go to the documentation of this file.
1 //===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert MemRef dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "llvm/Support/Debug.h"
21 #include <optional>
22 
23 #define DEBUG_TYPE "memref-to-spirv-pattern"
24 
25 using namespace mlir;
26 
27 //===----------------------------------------------------------------------===//
28 // Utility functions
29 //===----------------------------------------------------------------------===//
30 
31 /// Returns the offset of the value in `targetBits` representation.
32 ///
33 /// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
34 /// It's assumed to be non-negative.
35 ///
36 /// When accessing an element in the array treating as having elements of
37 /// `targetBits`, multiple values are loaded in the same time. The method
38 /// returns the offset where the `srcIdx` locates in the value. For example, if
39 /// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is
40 /// located at (x % 4) * 8. Because there are four elements in one i32, and one
41 /// element has 8 bits.
42 static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
43  int targetBits, OpBuilder &builder) {
44  assert(targetBits % sourceBits == 0);
45  Type type = srcIdx.getType();
46  IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
47  auto idx = builder.create<spirv::ConstantOp>(loc, type, idxAttr);
48  IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
49  auto srcBitsValue = builder.create<spirv::ConstantOp>(loc, type, srcBitsAttr);
50  auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
51  return builder.create<spirv::IMulOp>(loc, type, m, srcBitsValue);
52 }
53 
54 /// Returns an adjusted spirv::AccessChainOp. Based on the
55 /// extension/capabilities, certain integer bitwidths `sourceBits` might not be
56 /// supported. During conversion if a memref of an unsupported type is used,
57 /// load/stores to this memref need to be modified to use a supported higher
/// bitwidth `targetBits` and extracting the required bits. For accessing a
59 /// 1D array (spirv.array or spirv.rtarray), the last index is modified to load
60 /// the bits needed. The extraction of the actual bits needed are handled
61 /// separately. Note that this only works for a 1-D tensor.
62 static Value
64  spirv::AccessChainOp op, int sourceBits,
65  int targetBits, OpBuilder &builder) {
66  assert(targetBits % sourceBits == 0);
67  const auto loc = op.getLoc();
68  Value lastDim = op->getOperand(op.getNumOperands() - 1);
69  Type type = lastDim.getType();
70  IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
71  auto idx = builder.create<spirv::ConstantOp>(loc, type, attr);
72  auto indices = llvm::to_vector<4>(op.getIndices());
73  // There are two elements if this is a 1-D tensor.
74  assert(indices.size() == 2);
75  indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
76  Type t = typeConverter.convertType(op.getComponentPtr().getType());
77  return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
78 }
79 
80 /// Casts the given `srcBool` into an integer of `dstType`.
81 static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
82  OpBuilder &builder) {
83  assert(srcBool.getType().isInteger(1));
84  if (dstType.isInteger(1))
85  return srcBool;
86  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
87  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
88  return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
89 }
90 
91 /// Returns the `targetBits`-bit value shifted by the given `offset`, and cast
92 /// to the type destination type, and masked.
93 static Value shiftValue(Location loc, Value value, Value offset, Value mask,
94  OpBuilder &builder) {
95  IntegerType dstType = cast<IntegerType>(mask.getType());
96  int targetBits = static_cast<int>(dstType.getWidth());
97  int valueBits = value.getType().getIntOrFloatBitWidth();
98  assert(valueBits <= targetBits);
99 
100  if (valueBits == 1) {
101  value = castBoolToIntN(loc, value, dstType, builder);
102  } else {
103  if (valueBits < targetBits) {
104  value = builder.create<spirv::UConvertOp>(
105  loc, builder.getIntegerType(targetBits), value);
106  }
107 
108  value = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
109  }
110  return builder.create<spirv::ShiftLeftLogicalOp>(loc, value.getType(), value,
111  offset);
112 }
113 
114 /// Returns true if the allocations of memref `type` generated from `allocOp`
115 /// can be lowered to SPIR-V.
116 static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
117  if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
118  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
119  if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
120  return false;
121  } else if (isa<memref::AllocaOp>(allocOp)) {
122  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
123  if (!sc || sc.getValue() != spirv::StorageClass::Function)
124  return false;
125  } else {
126  return false;
127  }
128 
129  // Currently only support static shape and int or float or vector of int or
130  // float element type.
131  if (!type.hasStaticShape())
132  return false;
133 
134  Type elementType = type.getElementType();
135  if (auto vecType = dyn_cast<VectorType>(elementType))
136  elementType = vecType.getElementType();
137  return elementType.isIntOrFloat();
138 }
139 
/// Returns the scope to use for atomic operations used for emulating store
141 /// operations of unsupported integer bitwidths, based on the memref
142 /// type. Returns std::nullopt on failure.
143 static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
144  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
145  switch (sc.getValue()) {
146  case spirv::StorageClass::StorageBuffer:
147  return spirv::Scope::Device;
148  case spirv::StorageClass::Workgroup:
149  return spirv::Scope::Workgroup;
150  default:
151  break;
152  }
153  return {};
154 }
155 
156 /// Casts the given `srcInt` into a boolean value.
157 static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
158  if (srcInt.getType().isInteger(1))
159  return srcInt;
160 
161  auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
162  return builder.create<spirv::IEqualOp>(loc, srcInt, one);
163 }
164 
165 //===----------------------------------------------------------------------===//
166 // Operation conversion
167 //===----------------------------------------------------------------------===//
168 
169 // Note that DRR cannot be used for the patterns in this file: we may need to
170 // convert type along the way, which requires ConversionPattern. DRR generates
171 // normal RewritePattern.
172 
173 namespace {
174 
175 /// Converts memref.alloca to SPIR-V Function variables.
176 class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
177 public:
179 
181  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
182  ConversionPatternRewriter &rewriter) const override;
183 };
184 
185 /// Converts an allocation operation to SPIR-V. Currently only supports lowering
186 /// to Workgroup memory when the size is constant. Note that this pattern needs
187 /// to be applied in a pass that runs at least at spirv.module scope since it
/// will add global variables into the spirv.module.
189 class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
190 public:
192 
194  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
195  ConversionPatternRewriter &rewriter) const override;
196 };
197 
/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
199 class AtomicRMWOpPattern final
200  : public OpConversionPattern<memref::AtomicRMWOp> {
201 public:
203 
205  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
206  ConversionPatternRewriter &rewriter) const override;
207 };
208 
/// Removes a deallocation if it is a supported allocation. Currently only
210 /// removes deallocation if the memory space is workgroup memory.
211 class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
212 public:
214 
216  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
217  ConversionPatternRewriter &rewriter) const override;
218 };
219 
220 /// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
221 class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
222 public:
224 
226  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
227  ConversionPatternRewriter &rewriter) const override;
228 };
229 
230 /// Converts memref.load to spirv.Load + spirv.AccessChain.
231 class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
232 public:
234 
236  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
237  ConversionPatternRewriter &rewriter) const override;
238 };
239 
240 /// Converts memref.store to spirv.Store on integers.
241 class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
242 public:
244 
246  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
247  ConversionPatternRewriter &rewriter) const override;
248 };
249 
250 /// Converts memref.memory_space_cast to the appropriate spirv cast operations.
251 class MemorySpaceCastOpPattern final
252  : public OpConversionPattern<memref::MemorySpaceCastOp> {
253 public:
255 
257  matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
258  ConversionPatternRewriter &rewriter) const override;
259 };
260 
261 /// Converts memref.store to spirv.Store.
262 class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
263 public:
265 
267  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
268  ConversionPatternRewriter &rewriter) const override;
269 };
270 
271 class ReinterpretCastPattern final
272  : public OpConversionPattern<memref::ReinterpretCastOp> {
273 public:
275 
277  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
278  ConversionPatternRewriter &rewriter) const override;
279 };
280 
281 class CastPattern final : public OpConversionPattern<memref::CastOp> {
282 public:
284 
286  matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
287  ConversionPatternRewriter &rewriter) const override {
288  Value src = adaptor.getSource();
289  Type srcType = src.getType();
290 
291  const TypeConverter *converter = getTypeConverter();
292  Type dstType = converter->convertType(op.getType());
293  if (srcType != dstType)
294  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
295  diag << "types doesn't match: " << srcType << " and " << dstType;
296  });
297 
298  rewriter.replaceOp(op, src);
299  return success();
300  }
301 };
302 
303 } // namespace
304 
305 //===----------------------------------------------------------------------===//
306 // AllocaOp
307 //===----------------------------------------------------------------------===//
308 
310 AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
311  ConversionPatternRewriter &rewriter) const {
312  MemRefType allocType = allocaOp.getType();
313  if (!isAllocationSupported(allocaOp, allocType))
314  return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");
315 
316  // Get the SPIR-V type for the allocation.
317  Type spirvType = getTypeConverter()->convertType(allocType);
318  if (!spirvType)
319  return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");
320 
321  rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
322  spirv::StorageClass::Function,
323  /*initializer=*/nullptr);
324  return success();
325 }
326 
327 //===----------------------------------------------------------------------===//
328 // AllocOp
329 //===----------------------------------------------------------------------===//
330 
332 AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
333  ConversionPatternRewriter &rewriter) const {
334  MemRefType allocType = operation.getType();
335  if (!isAllocationSupported(operation, allocType))
336  return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
337 
338  // Get the SPIR-V type for the allocation.
339  Type spirvType = getTypeConverter()->convertType(allocType);
340  if (!spirvType)
341  return rewriter.notifyMatchFailure(operation, "type conversion failed");
342 
343  // Insert spirv.GlobalVariable for this allocation.
344  Operation *parent =
345  SymbolTable::getNearestSymbolTable(operation->getParentOp());
346  if (!parent)
347  return failure();
348  Location loc = operation.getLoc();
349  spirv::GlobalVariableOp varOp;
350  {
351  OpBuilder::InsertionGuard guard(rewriter);
352  Block &entryBlock = *parent->getRegion(0).begin();
353  rewriter.setInsertionPointToStart(&entryBlock);
354  auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
355  std::string varName =
356  std::string("__workgroup_mem__") +
357  std::to_string(std::distance(varOps.begin(), varOps.end()));
358  varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
359  /*initializer=*/nullptr);
360  }
361 
362  // Get pointer to global variable at the current scope.
363  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
364  return success();
365 }
366 
367 //===----------------------------------------------------------------------===//
// AtomicRMWOp
369 //===----------------------------------------------------------------------===//
370 
372 AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
373  OpAdaptor adaptor,
374  ConversionPatternRewriter &rewriter) const {
375  if (isa<FloatType>(atomicOp.getType()))
376  return rewriter.notifyMatchFailure(atomicOp,
377  "unimplemented floating-point case");
378 
379  auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
380  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
381  if (!scope)
382  return rewriter.notifyMatchFailure(atomicOp,
383  "unsupported memref memory space");
384 
385  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
386  Type resultType = typeConverter.convertType(atomicOp.getType());
387  if (!resultType)
388  return rewriter.notifyMatchFailure(atomicOp,
389  "failed to convert result type");
390 
391  auto loc = atomicOp.getLoc();
392  Value ptr =
393  spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
394  adaptor.getIndices(), loc, rewriter);
395 
396  if (!ptr)
397  return failure();
398 
399 #define ATOMIC_CASE(kind, spirvOp) \
400  case arith::AtomicRMWKind::kind: \
401  rewriter.replaceOpWithNewOp<spirv::spirvOp>( \
402  atomicOp, resultType, ptr, *scope, \
403  spirv::MemorySemantics::AcquireRelease, adaptor.getValue()); \
404  break
405 
406  switch (atomicOp.getKind()) {
407  ATOMIC_CASE(addi, AtomicIAddOp);
408  ATOMIC_CASE(maxs, AtomicSMaxOp);
409  ATOMIC_CASE(maxu, AtomicUMaxOp);
410  ATOMIC_CASE(mins, AtomicSMinOp);
411  ATOMIC_CASE(minu, AtomicUMinOp);
412  ATOMIC_CASE(ori, AtomicOrOp);
413  ATOMIC_CASE(andi, AtomicAndOp);
414  default:
415  return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
416  }
417 
418 #undef ATOMIC_CASE
419 
420  return success();
421 }
422 
423 //===----------------------------------------------------------------------===//
424 // DeallocOp
425 //===----------------------------------------------------------------------===//
426 
428 DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
429  OpAdaptor adaptor,
430  ConversionPatternRewriter &rewriter) const {
431  MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
432  if (!isAllocationSupported(operation, deallocType))
433  return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
434  rewriter.eraseOp(operation);
435  return success();
436 }
437 
438 //===----------------------------------------------------------------------===//
439 // LoadOp
440 //===----------------------------------------------------------------------===//
441 
443 IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
444  ConversionPatternRewriter &rewriter) const {
445  auto loc = loadOp.getLoc();
446  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
447  if (!memrefType.getElementType().isSignlessInteger())
448  return failure();
449 
450  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
451  Value accessChain =
452  spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
453  adaptor.getIndices(), loc, rewriter);
454 
455  if (!accessChain)
456  return failure();
457 
458  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
459  bool isBool = srcBits == 1;
460  if (isBool)
461  srcBits = typeConverter.getOptions().boolNumBits;
462 
463  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
464  if (!pointerType)
465  return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");
466 
467  Type pointeeType = pointerType.getPointeeType();
468  Type dstType;
469  if (typeConverter.allows(spirv::Capability::Kernel)) {
470  if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
471  dstType = arrayType.getElementType();
472  else
473  dstType = pointeeType;
474  } else {
475  // For Vulkan we need to extract element from wrapping struct and array.
476  Type structElemType =
477  cast<spirv::StructType>(pointeeType).getElementType(0);
478  if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
479  dstType = arrayType.getElementType();
480  else
481  dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
482  }
483  int dstBits = dstType.getIntOrFloatBitWidth();
484  assert(dstBits % srcBits == 0);
485 
486  // If the rewritten load op has the same bit width, use the loading value
487  // directly.
488  if (srcBits == dstBits) {
489  Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain);
490  if (isBool)
491  loadVal = castIntNToBool(loc, loadVal, rewriter);
492  rewriter.replaceOp(loadOp, loadVal);
493  return success();
494  }
495 
496  // Bitcasting is currently unsupported for Kernel capability /
497  // spirv.PtrAccessChain.
498  if (typeConverter.allows(spirv::Capability::Kernel))
499  return failure();
500 
501  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
502  if (!accessChainOp)
503  return failure();
504 
505  // Assume that getElementPtr() works linearizely. If it's a scalar, the method
506  // still returns a linearized accessing. If the accessing is not linearized,
507  // there will be offset issues.
508  assert(accessChainOp.getIndices().size() == 2);
509  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
510  srcBits, dstBits, rewriter);
511  Value spvLoadOp = rewriter.create<spirv::LoadOp>(
512  loc, dstType, adjustedPtr,
513  loadOp->getAttrOfType<spirv::MemoryAccessAttr>(
514  spirv::attributeName<spirv::MemoryAccess>()),
515  loadOp->getAttrOfType<IntegerAttr>("alignment"));
516 
517  // Shift the bits to the rightmost.
518  // ____XXXX________ -> ____________XXXX
519  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
520  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
521  Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
522  loc, spvLoadOp.getType(), spvLoadOp, offset);
523 
524  // Apply the mask to extract corresponding bits.
525  Value mask = rewriter.create<spirv::ConstantOp>(
526  loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
527  result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
528 
529  // Apply sign extension on the loading value unconditionally. The signedness
530  // semantic is carried in the operator itself, we relies other pattern to
531  // handle the casting.
532  IntegerAttr shiftValueAttr =
533  rewriter.getIntegerAttr(dstType, dstBits - srcBits);
534  Value shiftValue =
535  rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
536  result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
537  shiftValue);
538  result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
539  shiftValue);
540 
541  rewriter.replaceOp(loadOp, result);
542 
543  assert(accessChainOp.use_empty());
544  rewriter.eraseOp(accessChainOp);
545 
546  return success();
547 }
548 
550 LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
551  ConversionPatternRewriter &rewriter) const {
552  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
553  if (memrefType.getElementType().isSignlessInteger())
554  return failure();
555  auto loadPtr = spirv::getElementPtr(
556  *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
557  adaptor.getIndices(), loadOp.getLoc(), rewriter);
558 
559  if (!loadPtr)
560  return failure();
561 
562  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
563  return success();
564 }
565 
567 IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
568  ConversionPatternRewriter &rewriter) const {
569  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
570  if (!memrefType.getElementType().isSignlessInteger())
571  return rewriter.notifyMatchFailure(storeOp,
572  "element type is not a signless int");
573 
574  auto loc = storeOp.getLoc();
575  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
576  Value accessChain =
577  spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
578  adaptor.getIndices(), loc, rewriter);
579 
580  if (!accessChain)
581  return rewriter.notifyMatchFailure(
582  storeOp, "failed to convert element pointer type");
583 
584  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
585 
586  bool isBool = srcBits == 1;
587  if (isBool)
588  srcBits = typeConverter.getOptions().boolNumBits;
589 
590  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
591  if (!pointerType)
592  return rewriter.notifyMatchFailure(storeOp,
593  "failed to convert memref type");
594 
595  Type pointeeType = pointerType.getPointeeType();
596  IntegerType dstType;
597  if (typeConverter.allows(spirv::Capability::Kernel)) {
598  if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
599  dstType = dyn_cast<IntegerType>(arrayType.getElementType());
600  else
601  dstType = dyn_cast<IntegerType>(pointeeType);
602  } else {
603  // For Vulkan we need to extract element from wrapping struct and array.
604  Type structElemType =
605  cast<spirv::StructType>(pointeeType).getElementType(0);
606  if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
607  dstType = dyn_cast<IntegerType>(arrayType.getElementType());
608  else
609  dstType = dyn_cast<IntegerType>(
610  cast<spirv::RuntimeArrayType>(structElemType).getElementType());
611  }
612 
613  if (!dstType)
614  return rewriter.notifyMatchFailure(
615  storeOp, "failed to determine destination element type");
616 
617  int dstBits = static_cast<int>(dstType.getWidth());
618  assert(dstBits % srcBits == 0);
619 
620  if (srcBits == dstBits) {
621  Value storeVal = adaptor.getValue();
622  if (isBool)
623  storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
624  rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal);
625  return success();
626  }
627 
628  // Bitcasting is currently unsupported for Kernel capability /
629  // spirv.PtrAccessChain.
630  if (typeConverter.allows(spirv::Capability::Kernel))
631  return failure();
632 
633  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
634  if (!accessChainOp)
635  return failure();
636 
637  // Since there are multiple threads in the processing, the emulation will be
638  // done with atomic operations. E.g., if the stored value is i8, rewrite the
639  // StoreOp to:
640  // 1) load a 32-bit integer
641  // 2) clear 8 bits in the loaded value
642  // 3) set 8 bits in the loaded value
643  // 4) store 32-bit value back
644  //
645  // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
646  // loaded 32-bit value and the shifted 8-bit store value) as another atomic
647  // step.
648  assert(accessChainOp.getIndices().size() == 2);
649  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
650  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
651 
652  // Create a mask to clear the destination. E.g., if it is the second i8 in
653  // i32, 0xFFFF00FF is created.
654  Value mask = rewriter.create<spirv::ConstantOp>(
655  loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
656  Value clearBitsMask =
657  rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
658  clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
659 
660  Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
661  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
662  srcBits, dstBits, rewriter);
663  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
664  if (!scope)
665  return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");
666 
667  Value result = rewriter.create<spirv::AtomicAndOp>(
668  loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
669  clearBitsMask);
670  result = rewriter.create<spirv::AtomicOrOp>(
671  loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
672  storeVal);
673 
674  // The AtomicOrOp has no side effect. Since it is already inserted, we can
675  // just remove the original StoreOp. Note that rewriter.replaceOp()
676  // doesn't work because it only accepts that the numbers of result are the
677  // same.
678  rewriter.eraseOp(storeOp);
679 
680  assert(accessChainOp.use_empty());
681  rewriter.eraseOp(accessChainOp);
682 
683  return success();
684 }
685 
686 //===----------------------------------------------------------------------===//
687 // MemorySpaceCastOp
688 //===----------------------------------------------------------------------===//
689 
690 LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
691  memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
692  ConversionPatternRewriter &rewriter) const {
693  Location loc = addrCastOp.getLoc();
694  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
695  if (!typeConverter.allows(spirv::Capability::Kernel))
696  return rewriter.notifyMatchFailure(
697  loc, "address space casts require kernel capability");
698 
699  auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
700  if (!sourceType)
701  return rewriter.notifyMatchFailure(
702  loc, "SPIR-V lowering requires ranked memref types");
703  auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());
704 
705  auto sourceStorageClassAttr =
706  dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
707  if (!sourceStorageClassAttr)
708  return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
709  diag << "source address space " << sourceType.getMemorySpace()
710  << " must be a SPIR-V storage class";
711  });
712  auto resultStorageClassAttr =
713  dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
714  if (!resultStorageClassAttr)
715  return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
716  diag << "result address space " << resultType.getMemorySpace()
717  << " must be a SPIR-V storage class";
718  });
719 
720  spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
721  spirv::StorageClass resultSc = resultStorageClassAttr.getValue();
722 
723  Value result = adaptor.getSource();
724  Type resultPtrType = typeConverter.convertType(resultType);
725  if (!resultPtrType)
726  return rewriter.notifyMatchFailure(addrCastOp,
727  "failed to convert memref type");
728 
729  Type genericPtrType = resultPtrType;
730  // SPIR-V doesn't have a general address space cast operation. Instead, it has
731  // conversions to and from generic pointers. To implement the general case,
732  // we use specific-to-generic conversions when the source class is not
733  // generic. Then when the result storage class is not generic, we convert the
734  // generic pointer (either the input on ar intermediate result) to that
735  // class. This also means that we'll need the intermediate generic pointer
736  // type if neither the source or destination have it.
737  if (sourceSc != spirv::StorageClass::Generic &&
738  resultSc != spirv::StorageClass::Generic) {
739  Type intermediateType =
740  MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
741  sourceType.getLayout(),
742  rewriter.getAttr<spirv::StorageClassAttr>(
743  spirv::StorageClass::Generic));
744  genericPtrType = typeConverter.convertType(intermediateType);
745  }
746  if (sourceSc != spirv::StorageClass::Generic) {
747  result =
748  rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
749  }
750  if (resultSc != spirv::StorageClass::Generic) {
751  result =
752  rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
753  }
754  rewriter.replaceOp(addrCastOp, result);
755  return success();
756 }
757 
759 StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
760  ConversionPatternRewriter &rewriter) const {
761  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
762  if (memrefType.getElementType().isSignlessInteger())
763  return rewriter.notifyMatchFailure(storeOp, "signless int");
764  auto storePtr = spirv::getElementPtr(
765  *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
766  adaptor.getIndices(), storeOp.getLoc(), rewriter);
767 
768  if (!storePtr)
769  return rewriter.notifyMatchFailure(storeOp, "type conversion failed");
770 
771  rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
772  adaptor.getValue());
773  return success();
774 }
775 
776 LogicalResult ReinterpretCastPattern::matchAndRewrite(
777  memref::ReinterpretCastOp op, OpAdaptor adaptor,
778  ConversionPatternRewriter &rewriter) const {
779  Value src = adaptor.getSource();
780  auto srcType = dyn_cast<spirv::PointerType>(src.getType());
781 
782  if (!srcType)
783  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
784  diag << "invalid src type " << src.getType();
785  });
786 
787  const TypeConverter *converter = getTypeConverter();
788 
789  auto dstType = converter->convertType<spirv::PointerType>(op.getType());
790  if (dstType != srcType)
791  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
792  diag << "invalid dst type " << op.getType();
793  });
794 
795  OpFoldResult offset =
796  getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
797  .front();
798  if (isConstantIntValue(offset, 0)) {
799  rewriter.replaceOp(op, src);
800  return success();
801  }
802 
803  Type intType = converter->convertType(rewriter.getIndexType());
804  if (!intType)
805  return rewriter.notifyMatchFailure(op, "failed to convert index type");
806 
807  Location loc = op.getLoc();
808  auto offsetValue = [&]() -> Value {
809  if (auto val = dyn_cast<Value>(offset))
810  return val;
811 
812  int64_t attrVal = cast<IntegerAttr>(offset.get<Attribute>()).getInt();
813  Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
814  return rewriter.create<spirv::ConstantOp>(loc, intType, attr);
815  }();
816 
817  rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
818  op, src, offsetValue, std::nullopt);
819  return success();
820 }
821 
822 //===----------------------------------------------------------------------===//
823 // Pattern population
824 //===----------------------------------------------------------------------===//
825 
826 namespace mlir {
828  RewritePatternSet &patterns) {
829  patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
830  DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
831  LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
832  ReinterpretCastPattern, CastPattern>(typeConverter,
833  patterns.getContext());
834 }
835 } // namespace mlir
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder)
Casts the given srcInt into a boolean value.
static Value shiftValue(Location loc, Value value, Value offset, Value mask, OpBuilder &builder)
Returns the targetBits-bit value shifted by the given offset, and cast to the type destination type,...
static Value adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)
Returns an adjusted spirv::AccessChainOp.
static std::optional< spirv::Scope > getAtomicOpScope(MemRefType type)
Returns the scope to use for atomic operations used for emulating store operations of unsupported integer bitwidths, based on the memref's memory space.
static bool isAllocationSupported(Operation *allocOp, MemRefType type)
Returns true if the allocations of memref type generated from allocOp can be lowered to SPIR-V.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
Returns the offset of the value in targetBits representation.
#define ATOMIC_CASE(kind, spirvOp)
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder)
Casts the given srcBool into an integer of dstType.
static std::string diag(const llvm::Value &value)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:216
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition: Block.h:186
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
IndexType getIndexType()
Definition: Builders.cpp:71
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:100
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
This class helps build Operations.
Definition: Builders.h:206
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:416
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents a single result from folding an operation.
Definition: OpDefinition.h:266
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:665
iterator begin()
Definition: Region.h:55
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
Type conversion from builtin types to SPIR-V types for shader interface.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Type conversion class.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
Definition: Types.cpp:59
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:117
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:123
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
This header declares functions that assist transformations in the MemRef dialect.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size as staticValues, but where all elements equal to ShapedType::kDynamic are replaced by the next value from dynamicValues.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26