MLIR  20.0.0git
MemRefToSPIRV.cpp
Go to the documentation of this file.
1 //===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert MemRef dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/Visitors.h"
24 #include "llvm/Support/Debug.h"
25 #include <cassert>
26 #include <optional>
27 
28 #define DEBUG_TYPE "memref-to-spirv-pattern"
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // Utility functions
34 //===----------------------------------------------------------------------===//
35 
36 /// Returns the offset of the value in `targetBits` representation.
37 ///
38 /// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
39 /// It's assumed to be non-negative.
40 ///
41 /// When accessing an element in the array treating as having elements of
42 /// `targetBits`, multiple values are loaded in the same time. The method
43 /// returns the offset where the `srcIdx` locates in the value. For example, if
44 /// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is
45 /// located at (x % 4) * 8. Because there are four elements in one i32, and one
46 /// element has 8 bits.
47 static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
48  int targetBits, OpBuilder &builder) {
49  assert(targetBits % sourceBits == 0);
50  Type type = srcIdx.getType();
51  IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
52  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
53  IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
54  auto srcBitsValue =
55  builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
56  auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
57  return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
58 }
59 
60 /// Returns an adjusted spirv::AccessChainOp. Based on the
61 /// extension/capabilities, certain integer bitwidths `sourceBits` might not be
62 /// supported. During conversion if a memref of an unsupported type is used,
63 /// load/stores to this memref need to be modified to use a supported higher
64 /// bitwidth `targetBits` and extracting the required bits. For an accessing a
65 /// 1D array (spirv.array or spirv.rtarray), the last index is modified to load
66 /// the bits needed. The extraction of the actual bits needed are handled
67 /// separately. Note that this only works for a 1-D tensor.
68 static Value
70  spirv::AccessChainOp op, int sourceBits,
71  int targetBits, OpBuilder &builder) {
72  assert(targetBits % sourceBits == 0);
73  const auto loc = op.getLoc();
74  Value lastDim = op->getOperand(op.getNumOperands() - 1);
75  Type type = lastDim.getType();
76  IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
77  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
78  auto indices = llvm::to_vector<4>(op.getIndices());
79  // There are two elements if this is a 1-D tensor.
80  assert(indices.size() == 2);
81  indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
82  Type t = typeConverter.convertType(op.getComponentPtr().getType());
83  return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
84 }
85 
86 /// Casts the given `srcBool` into an integer of `dstType`.
87 static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
88  OpBuilder &builder) {
89  assert(srcBool.getType().isInteger(1));
90  if (dstType.isInteger(1))
91  return srcBool;
92  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
93  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
94  return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
95  zero);
96 }
97 
98 /// Returns the `targetBits`-bit value shifted by the given `offset`, and cast
99 /// to the type destination type, and masked.
100 static Value shiftValue(Location loc, Value value, Value offset, Value mask,
101  OpBuilder &builder) {
102  IntegerType dstType = cast<IntegerType>(mask.getType());
103  int targetBits = static_cast<int>(dstType.getWidth());
104  int valueBits = value.getType().getIntOrFloatBitWidth();
105  assert(valueBits <= targetBits);
106 
107  if (valueBits == 1) {
108  value = castBoolToIntN(loc, value, dstType, builder);
109  } else {
110  if (valueBits < targetBits) {
111  value = builder.create<spirv::UConvertOp>(
112  loc, builder.getIntegerType(targetBits), value);
113  }
114 
115  value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
116  }
117  return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
118  value, offset);
119 }
120 
121 /// Returns true if the allocations of memref `type` generated from `allocOp`
122 /// can be lowered to SPIR-V.
123 static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
124  if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
125  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
126  if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
127  return false;
128  } else if (isa<memref::AllocaOp>(allocOp)) {
129  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
130  if (!sc || sc.getValue() != spirv::StorageClass::Function)
131  return false;
132  } else {
133  return false;
134  }
135 
136  // Currently only support static shape and int or float or vector of int or
137  // float element type.
138  if (!type.hasStaticShape())
139  return false;
140 
141  Type elementType = type.getElementType();
142  if (auto vecType = dyn_cast<VectorType>(elementType))
143  elementType = vecType.getElementType();
144  return elementType.isIntOrFloat();
145 }
146 
147 /// Returns the scope to use for atomic operations use for emulating store
148 /// operations of unsupported integer bitwidths, based on the memref
149 /// type. Returns std::nullopt on failure.
150 static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
151  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
152  switch (sc.getValue()) {
153  case spirv::StorageClass::StorageBuffer:
154  return spirv::Scope::Device;
155  case spirv::StorageClass::Workgroup:
156  return spirv::Scope::Workgroup;
157  default:
158  break;
159  }
160  return {};
161 }
162 
163 /// Casts the given `srcInt` into a boolean value.
164 static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
165  if (srcInt.getType().isInteger(1))
166  return srcInt;
167 
168  auto one = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder);
169  return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, one);
170 }
171 
172 //===----------------------------------------------------------------------===//
173 // Operation conversion
174 //===----------------------------------------------------------------------===//
175 
176 // Note that DRR cannot be used for the patterns in this file: we may need to
177 // convert type along the way, which requires ConversionPattern. DRR generates
178 // normal RewritePattern.
179 
180 namespace {
181 
182 /// Converts memref.alloca to SPIR-V Function variables.
183 class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
184 public:
186 
187  LogicalResult
188  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
189  ConversionPatternRewriter &rewriter) const override;
190 };
191 
192 /// Converts an allocation operation to SPIR-V. Currently only supports lowering
193 /// to Workgroup memory when the size is constant. Note that this pattern needs
194 /// to be applied in a pass that runs at least at spirv.module scope since it
195 /// wil ladd global variables into the spirv.module.
196 class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
197 public:
199 
200  LogicalResult
201  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
202  ConversionPatternRewriter &rewriter) const override;
203 };
204 
205 /// Converts memref.automic_rmw operations to SPIR-V atomic operations.
206 class AtomicRMWOpPattern final
207  : public OpConversionPattern<memref::AtomicRMWOp> {
208 public:
210 
211  LogicalResult
212  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
213  ConversionPatternRewriter &rewriter) const override;
214 };
215 
216 /// Removed a deallocation if it is a supported allocation. Currently only
217 /// removes deallocation if the memory space is workgroup memory.
218 class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
219 public:
221 
222  LogicalResult
223  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
224  ConversionPatternRewriter &rewriter) const override;
225 };
226 
227 /// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
228 class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
229 public:
231 
232  LogicalResult
233  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
234  ConversionPatternRewriter &rewriter) const override;
235 };
236 
237 /// Converts memref.load to spirv.Load + spirv.AccessChain.
238 class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
239 public:
241 
242  LogicalResult
243  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
244  ConversionPatternRewriter &rewriter) const override;
245 };
246 
247 /// Converts memref.store to spirv.Store on integers.
248 class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
249 public:
251 
252  LogicalResult
253  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
254  ConversionPatternRewriter &rewriter) const override;
255 };
256 
257 /// Converts memref.memory_space_cast to the appropriate spirv cast operations.
258 class MemorySpaceCastOpPattern final
259  : public OpConversionPattern<memref::MemorySpaceCastOp> {
260 public:
262 
263  LogicalResult
264  matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
265  ConversionPatternRewriter &rewriter) const override;
266 };
267 
268 /// Converts memref.store to spirv.Store.
269 class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
270 public:
272 
273  LogicalResult
274  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
275  ConversionPatternRewriter &rewriter) const override;
276 };
277 
278 class ReinterpretCastPattern final
279  : public OpConversionPattern<memref::ReinterpretCastOp> {
280 public:
282 
283  LogicalResult
284  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
285  ConversionPatternRewriter &rewriter) const override;
286 };
287 
288 class CastPattern final : public OpConversionPattern<memref::CastOp> {
289 public:
291 
292  LogicalResult
293  matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
294  ConversionPatternRewriter &rewriter) const override {
295  Value src = adaptor.getSource();
296  Type srcType = src.getType();
297 
298  const TypeConverter *converter = getTypeConverter();
299  Type dstType = converter->convertType(op.getType());
300  if (srcType != dstType)
301  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
302  diag << "types doesn't match: " << srcType << " and " << dstType;
303  });
304 
305  rewriter.replaceOp(op, src);
306  return success();
307  }
308 };
309 
310 } // namespace
311 
312 //===----------------------------------------------------------------------===//
313 // AllocaOp
314 //===----------------------------------------------------------------------===//
315 
316 LogicalResult
317 AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
318  ConversionPatternRewriter &rewriter) const {
319  MemRefType allocType = allocaOp.getType();
320  if (!isAllocationSupported(allocaOp, allocType))
321  return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");
322 
323  // Get the SPIR-V type for the allocation.
324  Type spirvType = getTypeConverter()->convertType(allocType);
325  if (!spirvType)
326  return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");
327 
328  rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
329  spirv::StorageClass::Function,
330  /*initializer=*/nullptr);
331  return success();
332 }
333 
334 //===----------------------------------------------------------------------===//
335 // AllocOp
336 //===----------------------------------------------------------------------===//
337 
338 LogicalResult
339 AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
340  ConversionPatternRewriter &rewriter) const {
341  MemRefType allocType = operation.getType();
342  if (!isAllocationSupported(operation, allocType))
343  return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
344 
345  // Get the SPIR-V type for the allocation.
346  Type spirvType = getTypeConverter()->convertType(allocType);
347  if (!spirvType)
348  return rewriter.notifyMatchFailure(operation, "type conversion failed");
349 
350  // Insert spirv.GlobalVariable for this allocation.
351  Operation *parent =
352  SymbolTable::getNearestSymbolTable(operation->getParentOp());
353  if (!parent)
354  return failure();
355  Location loc = operation.getLoc();
356  spirv::GlobalVariableOp varOp;
357  {
358  OpBuilder::InsertionGuard guard(rewriter);
359  Block &entryBlock = *parent->getRegion(0).begin();
360  rewriter.setInsertionPointToStart(&entryBlock);
361  auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
362  std::string varName =
363  std::string("__workgroup_mem__") +
364  std::to_string(std::distance(varOps.begin(), varOps.end()));
365  varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
366  /*initializer=*/nullptr);
367  }
368 
369  // Get pointer to global variable at the current scope.
370  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
371  return success();
372 }
373 
374 //===----------------------------------------------------------------------===//
// AtomicRMWOp
376 //===----------------------------------------------------------------------===//
377 
378 LogicalResult
379 AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
380  OpAdaptor adaptor,
381  ConversionPatternRewriter &rewriter) const {
382  if (isa<FloatType>(atomicOp.getType()))
383  return rewriter.notifyMatchFailure(atomicOp,
384  "unimplemented floating-point case");
385 
386  auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
387  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
388  if (!scope)
389  return rewriter.notifyMatchFailure(atomicOp,
390  "unsupported memref memory space");
391 
392  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
393  Type resultType = typeConverter.convertType(atomicOp.getType());
394  if (!resultType)
395  return rewriter.notifyMatchFailure(atomicOp,
396  "failed to convert result type");
397 
398  auto loc = atomicOp.getLoc();
399  Value ptr =
400  spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
401  adaptor.getIndices(), loc, rewriter);
402 
403  if (!ptr)
404  return failure();
405 
406 #define ATOMIC_CASE(kind, spirvOp) \
407  case arith::AtomicRMWKind::kind: \
408  rewriter.replaceOpWithNewOp<spirv::spirvOp>( \
409  atomicOp, resultType, ptr, *scope, \
410  spirv::MemorySemantics::AcquireRelease, adaptor.getValue()); \
411  break
412 
413  switch (atomicOp.getKind()) {
414  ATOMIC_CASE(addi, AtomicIAddOp);
415  ATOMIC_CASE(maxs, AtomicSMaxOp);
416  ATOMIC_CASE(maxu, AtomicUMaxOp);
417  ATOMIC_CASE(mins, AtomicSMinOp);
418  ATOMIC_CASE(minu, AtomicUMinOp);
419  ATOMIC_CASE(ori, AtomicOrOp);
420  ATOMIC_CASE(andi, AtomicAndOp);
421  default:
422  return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
423  }
424 
425 #undef ATOMIC_CASE
426 
427  return success();
428 }
429 
430 //===----------------------------------------------------------------------===//
431 // DeallocOp
432 //===----------------------------------------------------------------------===//
433 
434 LogicalResult
435 DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
436  OpAdaptor adaptor,
437  ConversionPatternRewriter &rewriter) const {
438  MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
439  if (!isAllocationSupported(operation, deallocType))
440  return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
441  rewriter.eraseOp(operation);
442  return success();
443 }
444 
445 //===----------------------------------------------------------------------===//
446 // LoadOp
447 //===----------------------------------------------------------------------===//
448 
450  spirv::MemoryAccessAttr memoryAccess;
451  IntegerAttr alignment;
452 };
453 
454 /// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
455 /// any.
456 static FailureOr<MemoryRequirements>
457 calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
458  MLIRContext *ctx = accessedPtr.getContext();
459 
460  auto memoryAccess = spirv::MemoryAccess::None;
461  if (isNontemporal) {
462  memoryAccess = spirv::MemoryAccess::Nontemporal;
463  }
464 
465  auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
466  if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
467  if (memoryAccess == spirv::MemoryAccess::None) {
468  return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
469  }
470  return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess),
471  IntegerAttr{}};
472  }
473 
474  // PhysicalStorageBuffers require the `Aligned` attribute.
475  auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
476  if (!pointeeType)
477  return failure();
478 
479  // For scalar types, the alignment is determined by their size.
480  std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
481  if (!sizeInBytes.has_value())
482  return failure();
483 
484  memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
485  auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
486  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
487  return MemoryRequirements{memAccessAttr, alignment};
488 }
489 
490 /// Given an accessed SPIR-V pointer and the original memref load/store
491 /// `memAccess` op, calculates the alignment requirements, if any. Takes into
492 /// account the alignment attributes applied to the load/store op.
493 template <class LoadOrStoreOp>
494 static FailureOr<MemoryRequirements>
495 calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
496  static_assert(
497  llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
498  "Must be called on either memref::LoadOp or memref::StoreOp");
499 
500  Operation *memrefAccessOp = loadOrStoreOp.getOperation();
501  auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
502  spirv::attributeName<spirv::MemoryAccess>());
503  auto memrefAlignment =
504  memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
505  if (memrefMemAccess && memrefAlignment)
506  return MemoryRequirements{memrefMemAccess, memrefAlignment};
507 
508  return calculateMemoryRequirements(accessedPtr,
509  loadOrStoreOp.getNontemporal());
510 }
511 
512 LogicalResult
513 IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
514  ConversionPatternRewriter &rewriter) const {
515  auto loc = loadOp.getLoc();
516  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
517  if (!memrefType.getElementType().isSignlessInteger())
518  return failure();
519 
520  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
521  Value accessChain =
522  spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
523  adaptor.getIndices(), loc, rewriter);
524 
525  if (!accessChain)
526  return failure();
527 
528  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
529  bool isBool = srcBits == 1;
530  if (isBool)
531  srcBits = typeConverter.getOptions().boolNumBits;
532 
533  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
534  if (!pointerType)
535  return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");
536 
537  Type pointeeType = pointerType.getPointeeType();
538  Type dstType;
539  if (typeConverter.allows(spirv::Capability::Kernel)) {
540  if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
541  dstType = arrayType.getElementType();
542  else
543  dstType = pointeeType;
544  } else {
545  // For Vulkan we need to extract element from wrapping struct and array.
546  Type structElemType =
547  cast<spirv::StructType>(pointeeType).getElementType(0);
548  if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
549  dstType = arrayType.getElementType();
550  else
551  dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
552  }
553  int dstBits = dstType.getIntOrFloatBitWidth();
554  assert(dstBits % srcBits == 0);
555 
556  // If the rewritten load op has the same bit width, use the loading value
557  // directly.
558  if (srcBits == dstBits) {
559  auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp);
560  if (failed(memoryRequirements))
561  return rewriter.notifyMatchFailure(
562  loadOp, "failed to determine memory requirements");
563 
564  auto [memoryAccess, alignment] = *memoryRequirements;
565  Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
566  memoryAccess, alignment);
567  if (isBool)
568  loadVal = castIntNToBool(loc, loadVal, rewriter);
569  rewriter.replaceOp(loadOp, loadVal);
570  return success();
571  }
572 
573  // Bitcasting is currently unsupported for Kernel capability /
574  // spirv.PtrAccessChain.
575  if (typeConverter.allows(spirv::Capability::Kernel))
576  return failure();
577 
578  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
579  if (!accessChainOp)
580  return failure();
581 
582  // Assume that getElementPtr() works linearizely. If it's a scalar, the method
583  // still returns a linearized accessing. If the accessing is not linearized,
584  // there will be offset issues.
585  assert(accessChainOp.getIndices().size() == 2);
586  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
587  srcBits, dstBits, rewriter);
588  auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp);
589  if (failed(memoryRequirements))
590  return rewriter.notifyMatchFailure(
591  loadOp, "failed to determine memory requirements");
592 
593  auto [memoryAccess, alignment] = *memoryRequirements;
594  Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
595  memoryAccess, alignment);
596 
597  // Shift the bits to the rightmost.
598  // ____XXXX________ -> ____________XXXX
599  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
600  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
601  Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
602  loc, spvLoadOp.getType(), spvLoadOp, offset);
603 
604  // Apply the mask to extract corresponding bits.
605  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
606  loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
607  result =
608  rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
609 
610  // Apply sign extension on the loading value unconditionally. The signedness
611  // semantic is carried in the operator itself, we relies other pattern to
612  // handle the casting.
613  IntegerAttr shiftValueAttr =
614  rewriter.getIntegerAttr(dstType, dstBits - srcBits);
615  Value shiftValue =
616  rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
617  result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
618  result, shiftValue);
619  result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
620  loc, dstType, result, shiftValue);
621 
622  rewriter.replaceOp(loadOp, result);
623 
624  assert(accessChainOp.use_empty());
625  rewriter.eraseOp(accessChainOp);
626 
627  return success();
628 }
629 
630 LogicalResult
631 LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
632  ConversionPatternRewriter &rewriter) const {
633  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
634  if (memrefType.getElementType().isSignlessInteger())
635  return failure();
636  Value loadPtr = spirv::getElementPtr(
637  *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
638  adaptor.getIndices(), loadOp.getLoc(), rewriter);
639 
640  if (!loadPtr)
641  return failure();
642 
643  auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
644  if (failed(memoryRequirements))
645  return rewriter.notifyMatchFailure(
646  loadOp, "failed to determine memory requirements");
647 
648  auto [memoryAccess, alignment] = *memoryRequirements;
649  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
650  alignment);
651  return success();
652 }
653 
654 LogicalResult
655 IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
656  ConversionPatternRewriter &rewriter) const {
657  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
658  if (!memrefType.getElementType().isSignlessInteger())
659  return rewriter.notifyMatchFailure(storeOp,
660  "element type is not a signless int");
661 
662  auto loc = storeOp.getLoc();
663  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
664  Value accessChain =
665  spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
666  adaptor.getIndices(), loc, rewriter);
667 
668  if (!accessChain)
669  return rewriter.notifyMatchFailure(
670  storeOp, "failed to convert element pointer type");
671 
672  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
673 
674  bool isBool = srcBits == 1;
675  if (isBool)
676  srcBits = typeConverter.getOptions().boolNumBits;
677 
678  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
679  if (!pointerType)
680  return rewriter.notifyMatchFailure(storeOp,
681  "failed to convert memref type");
682 
683  Type pointeeType = pointerType.getPointeeType();
684  IntegerType dstType;
685  if (typeConverter.allows(spirv::Capability::Kernel)) {
686  if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
687  dstType = dyn_cast<IntegerType>(arrayType.getElementType());
688  else
689  dstType = dyn_cast<IntegerType>(pointeeType);
690  } else {
691  // For Vulkan we need to extract element from wrapping struct and array.
692  Type structElemType =
693  cast<spirv::StructType>(pointeeType).getElementType(0);
694  if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
695  dstType = dyn_cast<IntegerType>(arrayType.getElementType());
696  else
697  dstType = dyn_cast<IntegerType>(
698  cast<spirv::RuntimeArrayType>(structElemType).getElementType());
699  }
700 
701  if (!dstType)
702  return rewriter.notifyMatchFailure(
703  storeOp, "failed to determine destination element type");
704 
705  int dstBits = static_cast<int>(dstType.getWidth());
706  assert(dstBits % srcBits == 0);
707 
708  if (srcBits == dstBits) {
709  auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
710  if (failed(memoryRequirements))
711  return rewriter.notifyMatchFailure(
712  storeOp, "failed to determine memory requirements");
713 
714  auto [memoryAccess, alignment] = *memoryRequirements;
715  Value storeVal = adaptor.getValue();
716  if (isBool)
717  storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
718  rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
719  memoryAccess, alignment);
720  return success();
721  }
722 
723  // Bitcasting is currently unsupported for Kernel capability /
724  // spirv.PtrAccessChain.
725  if (typeConverter.allows(spirv::Capability::Kernel))
726  return failure();
727 
728  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
729  if (!accessChainOp)
730  return failure();
731 
732  // Since there are multiple threads in the processing, the emulation will be
733  // done with atomic operations. E.g., if the stored value is i8, rewrite the
734  // StoreOp to:
735  // 1) load a 32-bit integer
736  // 2) clear 8 bits in the loaded value
737  // 3) set 8 bits in the loaded value
738  // 4) store 32-bit value back
739  //
740  // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
741  // loaded 32-bit value and the shifted 8-bit store value) as another atomic
742  // step.
743  assert(accessChainOp.getIndices().size() == 2);
744  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
745  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
746 
747  // Create a mask to clear the destination. E.g., if it is the second i8 in
748  // i32, 0xFFFF00FF is created.
749  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
750  loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
751  Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
752  loc, dstType, mask, offset);
753  clearBitsMask =
754  rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
755 
756  Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
757  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
758  srcBits, dstBits, rewriter);
759  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
760  if (!scope)
761  return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");
762 
763  Value result = rewriter.create<spirv::AtomicAndOp>(
764  loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
765  clearBitsMask);
766  result = rewriter.create<spirv::AtomicOrOp>(
767  loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
768  storeVal);
769 
770  // The AtomicOrOp has no side effect. Since it is already inserted, we can
771  // just remove the original StoreOp. Note that rewriter.replaceOp()
772  // doesn't work because it only accepts that the numbers of result are the
773  // same.
774  rewriter.eraseOp(storeOp);
775 
776  assert(accessChainOp.use_empty());
777  rewriter.eraseOp(accessChainOp);
778 
779  return success();
780 }
781 
782 //===----------------------------------------------------------------------===//
783 // MemorySpaceCastOp
784 //===----------------------------------------------------------------------===//
785 
786 LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
787  memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
788  ConversionPatternRewriter &rewriter) const {
789  Location loc = addrCastOp.getLoc();
790  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
791  if (!typeConverter.allows(spirv::Capability::Kernel))
792  return rewriter.notifyMatchFailure(
793  loc, "address space casts require kernel capability");
794 
795  auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
796  if (!sourceType)
797  return rewriter.notifyMatchFailure(
798  loc, "SPIR-V lowering requires ranked memref types");
799  auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());
800 
801  auto sourceStorageClassAttr =
802  dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
803  if (!sourceStorageClassAttr)
804  return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
805  diag << "source address space " << sourceType.getMemorySpace()
806  << " must be a SPIR-V storage class";
807  });
808  auto resultStorageClassAttr =
809  dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
810  if (!resultStorageClassAttr)
811  return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
812  diag << "result address space " << resultType.getMemorySpace()
813  << " must be a SPIR-V storage class";
814  });
815 
816  spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
817  spirv::StorageClass resultSc = resultStorageClassAttr.getValue();
818 
819  Value result = adaptor.getSource();
820  Type resultPtrType = typeConverter.convertType(resultType);
821  if (!resultPtrType)
822  return rewriter.notifyMatchFailure(addrCastOp,
823  "failed to convert memref type");
824 
825  Type genericPtrType = resultPtrType;
826  // SPIR-V doesn't have a general address space cast operation. Instead, it has
827  // conversions to and from generic pointers. To implement the general case,
828  // we use specific-to-generic conversions when the source class is not
829  // generic. Then when the result storage class is not generic, we convert the
830  // generic pointer (either the input on ar intermediate result) to that
831  // class. This also means that we'll need the intermediate generic pointer
832  // type if neither the source or destination have it.
833  if (sourceSc != spirv::StorageClass::Generic &&
834  resultSc != spirv::StorageClass::Generic) {
835  Type intermediateType =
836  MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
837  sourceType.getLayout(),
838  rewriter.getAttr<spirv::StorageClassAttr>(
839  spirv::StorageClass::Generic));
840  genericPtrType = typeConverter.convertType(intermediateType);
841  }
842  if (sourceSc != spirv::StorageClass::Generic) {
843  result =
844  rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
845  }
846  if (resultSc != spirv::StorageClass::Generic) {
847  result =
848  rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
849  }
850  rewriter.replaceOp(addrCastOp, result);
851  return success();
852 }
853 
854 LogicalResult
855 StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
856  ConversionPatternRewriter &rewriter) const {
857  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
858  if (memrefType.getElementType().isSignlessInteger())
859  return rewriter.notifyMatchFailure(storeOp, "signless int");
860  auto storePtr = spirv::getElementPtr(
861  *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
862  adaptor.getIndices(), storeOp.getLoc(), rewriter);
863 
864  if (!storePtr)
865  return rewriter.notifyMatchFailure(storeOp, "type conversion failed");
866 
867  auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
868  if (failed(memoryRequirements))
869  return rewriter.notifyMatchFailure(
870  storeOp, "failed to determine memory requirements");
871 
872  auto [memoryAccess, alignment] = *memoryRequirements;
873  rewriter.replaceOpWithNewOp<spirv::StoreOp>(
874  storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
875  return success();
876 }
877 
878 LogicalResult ReinterpretCastPattern::matchAndRewrite(
879  memref::ReinterpretCastOp op, OpAdaptor adaptor,
880  ConversionPatternRewriter &rewriter) const {
881  Value src = adaptor.getSource();
882  auto srcType = dyn_cast<spirv::PointerType>(src.getType());
883 
884  if (!srcType)
885  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
886  diag << "invalid src type " << src.getType();
887  });
888 
889  const TypeConverter *converter = getTypeConverter();
890 
891  auto dstType = converter->convertType<spirv::PointerType>(op.getType());
892  if (dstType != srcType)
893  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
894  diag << "invalid dst type " << op.getType();
895  });
896 
897  OpFoldResult offset =
898  getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
899  .front();
900  if (isConstantIntValue(offset, 0)) {
901  rewriter.replaceOp(op, src);
902  return success();
903  }
904 
905  Type intType = converter->convertType(rewriter.getIndexType());
906  if (!intType)
907  return rewriter.notifyMatchFailure(op, "failed to convert index type");
908 
909  Location loc = op.getLoc();
910  auto offsetValue = [&]() -> Value {
911  if (auto val = dyn_cast<Value>(offset))
912  return val;
913 
914  int64_t attrVal = cast<IntegerAttr>(offset.get<Attribute>()).getInt();
915  Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
916  return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
917  }();
918 
919  rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
920  op, src, offsetValue, std::nullopt);
921  return success();
922 }
923 
924 //===----------------------------------------------------------------------===//
925 // Pattern population
926 //===----------------------------------------------------------------------===//
927 
928 namespace mlir {
930  RewritePatternSet &patterns) {
931  patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
932  DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
933  LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
934  ReinterpretCastPattern, CastPattern>(typeConverter,
935  patterns.getContext());
936 }
937 } // namespace mlir
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder)
Casts the given srcInt into a boolean value.
static Value shiftValue(Location loc, Value value, Value offset, Value mask, OpBuilder &builder)
Returns the targetBits-bit value shifted by the given offset, and cast to the type destination type,...
static Value adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)
Returns an adjusted spirv::AccessChainOp.
static std::optional< spirv::Scope > getAtomicOpScope(MemRefType type)
Returns the scope to use for atomic operations used for emulating store operations of unsupported inte...
static bool isAllocationSupported(Operation *allocOp, MemRefType type)
Returns true if the allocations of memref type generated from allocOp can be lowered to SPIR-V.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
Returns the offset of the value in targetBits representation.
#define ATOMIC_CASE(kind, spirvOp)
static FailureOr< MemoryRequirements > calculateMemoryRequirements(Value accessedPtr, bool isNontemporal)
Given an accessed SPIR-V pointer, calculates its alignment requirements, if any.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder)
Casts the given srcBool into an integer of dstType.
@ None
static std::string diag(const llvm::Value &value)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition: Block.h:193
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
IndexType getIndexType()
Definition: Builders.cpp:95
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:106
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:439
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:528
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:545
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
iterator begin()
Definition: Region.h:55
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Type conversion from builtin types to SPIR-V types for shader interface.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Type conversion class.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:66
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:127
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:132
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops.
spirv::MemoryAccessAttr memoryAccess