//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
#include <cassert>
#include <limits>
#include <optional>

#define DEBUG_TYPE "memref-to-spirv-pattern"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Returns the offset of the value in `targetBits` representation.
///
/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
/// It's assumed to be non-negative.
///
/// When accessing an element in the array treating it as having elements of
/// `targetBits`, multiple values are loaded at the same time. This method
/// returns the offset where `srcIdx` is located within the wider value. For
/// example, if `sourceBits` equals 8 and `targetBits` equals 32, the x-th
/// element is located at (x % 4) * 8, because there are four elements in one
/// i32 and each element has 8 bits.
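///
/// As an illustrative sketch (abbreviated, not verbatim output), for
/// `sourceBits` = 8 and `targetBits` = 32 the ops built below amount to:
///   %c4  = spirv.Constant 4 : i32   // targetBits / sourceBits
///   %c8  = spirv.Constant 8 : i32   // sourceBits
///   %m   = spirv.UMod %srcIdx, %c4 : i32
///   %off = spirv.IMul %m, %c8 : i32 // bit offset within the i32 word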
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
                                  int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  Type type = srcIdx.getType();
  IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
  IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
  auto srcBitsValue =
      builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
  auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
  return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
}

/// Returns an adjusted spirv::AccessChainOp. Based on the
/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
/// supported. During conversion, if a memref of an unsupported type is used,
/// loads/stores to this memref need to be modified to use a supported higher
/// bitwidth `targetBits` and to extract the required bits. When accessing a
/// 1-D array (spirv.array or spirv.rtarray), the last index is modified to
/// load the bits needed. The extraction of the actual bits needed is handled
/// separately. Note that this only works for a 1-D tensor.
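///
/// For example (a sketch, assuming `sourceBits` = 8 and `targetBits` = 32): an
/// access chain indexing element `i` of an i8 array is rewritten so that its
/// last index becomes `i / 4`, addressing the i32 word that contains the
/// element; getOffsetForBitwidth() then locates the byte within that word.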
static Value
adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
                             spirv::AccessChainOp op, int sourceBits,
                             int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  const auto loc = op.getLoc();
  Value lastDim = op->getOperand(op.getNumOperands() - 1);
  Type type = lastDim.getType();
  IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
  auto indices = llvm::to_vector<4>(op.getIndices());
  // There are two elements if this is a 1-D tensor.
  assert(indices.size() == 2);
  indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
  Type t = typeConverter.convertType(op.getComponentPtr().getType());
  return spirv::AccessChainOp::create(builder, loc, t, op.getBasePtr(),
                                      indices);
}

/// Casts the given `srcBool` into an integer of `dstType`.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
                            OpBuilder &builder) {
  assert(srcBool.getType().isInteger(1));
  if (dstType.isInteger(1))
    return srcBool;
  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
  return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
                                               zero);
}

/// Returns the `targetBits`-bit value shifted by the given `offset`, cast to
/// the destination type, and masked.
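///
/// For example (a sketch, assuming an i8 `value` stored into an i32 word with
/// a bit `offset` of 16 and a `mask` of 0xFF): the value is zero-extended to
/// i32, masked to (value & 0xFF), and shifted to ((value & 0xFF) << 16).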
static Value shiftValue(Location loc, Value value, Value offset, Value mask,
                        OpBuilder &builder) {
  IntegerType dstType = cast<IntegerType>(mask.getType());
  int targetBits = static_cast<int>(dstType.getWidth());
  int valueBits = value.getType().getIntOrFloatBitWidth();
  assert(valueBits <= targetBits);

  if (valueBits == 1) {
    value = castBoolToIntN(loc, value, dstType, builder);
  } else {
    if (valueBits < targetBits) {
      value = spirv::UConvertOp::create(
          builder, loc, builder.getIntegerType(targetBits), value);
    }

    value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
  }
  return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
                                                         value, offset);
}

/// Returns true if the allocations of memref `type` generated from `allocOp`
/// can be lowered to SPIR-V.
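///
/// For example (illustrative types, not exhaustive):
///   memref.alloc()  : memref<4x4xf32, #spirv.storage_class<Workgroup>> // ok
///   memref.alloca() : memref<16xi32, #spirv.storage_class<Function>>   // ok
///   memref.alloc()  : memref<?xf32, #spirv.storage_class<Workgroup>>   // no
///     (dynamic shapes are rejected below)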
static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
  if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
      return false;
  } else if (isa<memref::AllocaOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Function)
      return false;
  } else {
    return false;
  }

  // Currently only static shapes with int, float, or vector-of-int/float
  // element types are supported.
  if (!type.hasStaticShape())
    return false;

  Type elementType = type.getElementType();
  if (auto vecType = dyn_cast<VectorType>(elementType))
    elementType = vecType.getElementType();
  return elementType.isIntOrFloat();
}

/// Returns the scope to use for atomic operations used for emulating store
/// operations of unsupported integer bitwidths, based on the memref type.
/// Returns std::nullopt on failure.
static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
  if (!sc)
    return std::nullopt;
  switch (sc.getValue()) {
  case spirv::StorageClass::StorageBuffer:
    return spirv::Scope::Device;
  case spirv::StorageClass::Workgroup:
    return spirv::Scope::Workgroup;
  default:
    break;
  }
  return std::nullopt;
}

/// Casts the given `srcInt` into a boolean value.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
  if (srcInt.getType().isInteger(1))
    return srcInt;

  auto zero = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder);
  return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, zero);
}

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

// Note that DRR cannot be used for the patterns in this file: we may need to
// convert type along the way, which requires ConversionPattern. DRR generates
// normal RewritePattern.

namespace {

/// Converts memref.alloca to SPIR-V Function variables.
class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts an allocation operation to SPIR-V. Currently only supports lowering
/// to Workgroup memory when the size is constant. Note that this pattern needs
/// to be applied in a pass that runs at least at spirv.module scope since it
/// will add global variables into the spirv.module.
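///
/// For example (an illustrative sketch of the rewrite, not verbatim output):
///   %0 = memref.alloc() : memref<4xf32, #spirv.storage_class<Workgroup>>
/// becomes a module-level global plus a reference at the use site:
///   spirv.GlobalVariable @__workgroup_mem__0 : !spirv.ptr<..., Workgroup>
///   %0 = spirv.mlir.addressof @__workgroup_mem__0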
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
class AtomicRMWOpPattern final
    : public OpConversionPattern<memref::AtomicRMWOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Removes a deallocation if it is a supported allocation. Currently only
/// removes deallocations if the memory space is workgroup memory.
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Image + spirv.ImageFetch.
class ImageLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store on integers.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.memory_space_cast to the appropriate spirv cast operations.
class MemorySpaceCastOpPattern final
    : public OpConversionPattern<memref::MemorySpaceCastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store.
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ReinterpretCastPattern final
    : public OpConversionPattern<memref::ReinterpretCastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class CastPattern final : public OpConversionPattern<memref::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value src = adaptor.getSource();
    Type srcType = src.getType();

    const TypeConverter *converter = getTypeConverter();
    Type dstType = converter->convertType(op.getType());
    if (srcType != dstType)
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "types don't match: " << srcType << " and " << dstType;
      });

    rewriter.replaceOp(op, src);
    return success();
  }
};

/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
class ExtractAlignedPointerAsIndexOpPattern final
    : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = allocaOp.getType();
  if (!isAllocationSupported(allocaOp, allocType))
    return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");

  rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
                                                 spirv::StorageClass::Function,
                                                 /*initializer=*/nullptr);
  return success();
}

//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = operation.getType();
  if (!isAllocationSupported(operation, allocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(operation, "type conversion failed");

  // Insert spirv.GlobalVariable for this allocation.
  Operation *parent =
      SymbolTable::getNearestSymbolTable(operation->getParentOp());
  if (!parent)
    return failure();
  Location loc = operation.getLoc();
  spirv::GlobalVariableOp varOp;
  {
    OpBuilder::InsertionGuard guard(rewriter);
    Block &entryBlock = *parent->getRegion(0).begin();
    rewriter.setInsertionPointToStart(&entryBlock);
    auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
    std::string varName =
        std::string("__workgroup_mem__") +
        std::to_string(std::distance(varOps.begin(), varOps.end()));
    varOp = spirv::GlobalVariableOp::create(rewriter, loc, spirvType, varName,
                                            /*initializer=*/nullptr);
  }

  // Get pointer to global variable at the current scope.
  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
  return success();
}

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

LogicalResult
AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (isa<FloatType>(atomicOp.getType()))
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unimplemented floating-point case");

  auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unsupported memref memory space");

  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Type resultType = typeConverter.convertType(atomicOp.getType());
  if (!resultType)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "failed to convert result type");

  auto loc = atomicOp.getLoc();
  Value ptr =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!ptr)
    return failure();

#define ATOMIC_CASE(kind, spirvOp)                                            \
  case arith::AtomicRMWKind::kind:                                            \
    rewriter.replaceOpWithNewOp<spirv::spirvOp>(                              \
        atomicOp, resultType, ptr, *scope,                                    \
        spirv::MemorySemantics::AcquireRelease, adaptor.getValue());          \
    break

  switch (atomicOp.getKind()) {
    ATOMIC_CASE(addi, AtomicIAddOp);
    ATOMIC_CASE(maxs, AtomicSMaxOp);
    ATOMIC_CASE(maxu, AtomicUMaxOp);
    ATOMIC_CASE(mins, AtomicSMinOp);
    ATOMIC_CASE(minu, AtomicUMinOp);
    ATOMIC_CASE(ori, AtomicOrOp);
    ATOMIC_CASE(andi, AtomicAndOp);
  default:
    return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
  }

#undef ATOMIC_CASE

  return success();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult
DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
  if (!isAllocationSupported(operation, deallocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
  rewriter.eraseOp(operation);
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

struct MemoryRequirements {
  spirv::MemoryAccessAttr memoryAccess;
  IntegerAttr alignment;
};

/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
/// any.
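///
/// For example (a sketch): a PhysicalStorageBuffer pointer to an f32 with no
/// preferred alignment yields MemoryAccess::Aligned with alignment 4 (the
/// scalar size), whereas a StorageBuffer pointer with no preferred alignment
/// and no nontemporal hint needs no memory-access attributes at all.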
478 static FailureOr<MemoryRequirements>
479 calculateMemoryRequirements(Value accessedPtr, bool isNontemporal,
480  uint64_t preferredAlignment) {
481  if (preferredAlignment >= std::numeric_limits<uint32_t>::max()) {
482  return failure();
483  }
484 
485  MLIRContext *ctx = accessedPtr.getContext();
486 
487  auto memoryAccess = spirv::MemoryAccess::None;
488  if (isNontemporal) {
489  memoryAccess = spirv::MemoryAccess::Nontemporal;
490  }
491 
492  auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
493  bool mayOmitAlignment =
494  !preferredAlignment &&
495  ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer;
496  if (mayOmitAlignment) {
497  if (memoryAccess == spirv::MemoryAccess::None) {
498  return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
499  }
500  return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess),
501  IntegerAttr{}};
502  }
503 
504  // PhysicalStorageBuffers require the `Aligned` attribute.
505  // Other storage types may show an `Aligned` attribute.
506  auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
507  if (!pointeeType)
508  return failure();
509 
510  // For scalar types, the alignment is determined by their size.
511  std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
512  if (!sizeInBytes.has_value())
513  return failure();
514 
515  memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
516  auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
517  auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes;
518  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), alignmentValue);
519  return MemoryRequirements{memAccessAttr, alignment};
520 }

/// Given an accessed SPIR-V pointer and the original memref load/store
/// `memAccess` op, calculates the alignment requirements, if any. Takes into
/// account the alignment attributes applied to the load/store op.
template <class LoadOrStoreOp>
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
  static_assert(
      llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
      "Must be called on either memref::LoadOp or memref::StoreOp");

  return calculateMemoryRequirements(accessedPtr,
                                     loadOrStoreOp.getNontemporal(),
                                     loadOrStoreOp.getAlignment().value_or(0));
}
536 
537 LogicalResult
538 IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
539  ConversionPatternRewriter &rewriter) const {
540  auto loc = loadOp.getLoc();
541  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
542  if (!memrefType.getElementType().isSignlessInteger())
543  return failure();
544 
545  auto memorySpaceAttr =
546  dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
547  if (!memorySpaceAttr)
548  return rewriter.notifyMatchFailure(
549  loadOp, "missing memory space SPIR-V storage class attribute");
550 
551  if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
552  return rewriter.notifyMatchFailure(
553  loadOp,
554  "failed to lower memref in image storage class to storage buffer");
555 
556  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
557  Value accessChain =
558  spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
559  adaptor.getIndices(), loc, rewriter);
560 
561  if (!accessChain)
562  return failure();
563 
564  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
565  bool isBool = srcBits == 1;
566  if (isBool)
567  srcBits = typeConverter.getOptions().boolNumBits;
568 
569  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
570  if (!pointerType)
571  return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");
572 
573  Type pointeeType = pointerType.getPointeeType();
574  Type dstType;
575  if (typeConverter.allows(spirv::Capability::Kernel)) {
576  if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
577  dstType = arrayType.getElementType();
578  else
579  dstType = pointeeType;
580  } else {
581  // For Vulkan we need to extract element from wrapping struct and array.
582  Type structElemType =
583  cast<spirv::StructType>(pointeeType).getElementType(0);
584  if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
585  dstType = arrayType.getElementType();
586  else
587  dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
588  }
589  int dstBits = dstType.getIntOrFloatBitWidth();
590  assert(dstBits % srcBits == 0);
591 
592  // If the rewritten load op has the same bit width, use the loading value
593  // directly.
594  if (srcBits == dstBits) {
595  auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp);
596  if (failed(memoryRequirements))
597  return rewriter.notifyMatchFailure(
598  loadOp, "failed to determine memory requirements");
599 
600  auto [memoryAccess, alignment] = *memoryRequirements;
601  Value loadVal = spirv::LoadOp::create(rewriter, loc, accessChain,
602  memoryAccess, alignment);
603  if (isBool)
604  loadVal = castIntNToBool(loc, loadVal, rewriter);
605  rewriter.replaceOp(loadOp, loadVal);
606  return success();
607  }
608 
609  // Bitcasting is currently unsupported for Kernel capability /
610  // spirv.PtrAccessChain.
611  if (typeConverter.allows(spirv::Capability::Kernel))
612  return failure();
613 
614  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
615  if (!accessChainOp)
616  return failure();
617 
  // Assume that getElementPtr() produces a linearized access. Even for a
  // scalar, the method returns a linearized access chain; if the access were
  // not linearized, the offsets computed below would be wrong.
  assert(accessChainOp.getIndices().size() == 2);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  Value spvLoadOp = spirv::LoadOp::create(rewriter, loc, dstType, adjustedPtr,
                                          memoryAccess, alignment);

  // Shift the bits to the rightmost.
  // ____XXXX________ -> ____________XXXX
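  // As a concrete sketch for srcBits = 8 and dstBits = 32: loading element 2
  // of an i8 array reads the containing i32 word, shifts it right by
  // offset = (2 % 4) * 8 = 16 bits, masks with 0xFF, and then (below) shifts
  // left and right by 24 bits to sign-extend the byte into the full i32.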
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, spvLoadOp.getType(), spvLoadOp, offset);

  // Apply the mask to extract corresponding bits.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  result =
      rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);

  // Apply sign extension to the loaded value unconditionally. The signedness
  // semantics are carried by the operations themselves; we rely on other
  // patterns to handle the casting.
  IntegerAttr shiftValueAttr =
      rewriter.getIntegerAttr(dstType, dstBits - srcBits);
  Value shiftValue =
      rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
  result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
                                                            result, shiftValue);
  result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, dstType, result, shiftValue);

  rewriter.replaceOp(loadOp, result);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return failure();

  auto memorySpaceAttr =
      dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
  if (!memorySpaceAttr)
    return rewriter.notifyMatchFailure(
        loadOp, "missing memory space SPIR-V storage class attribute");

  if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
    return rewriter.notifyMatchFailure(
        loadOp,
        "failed to lower memref in image storage class to storage buffer");

  Value loadPtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), loadOp.getLoc(), rewriter);

  if (!loadPtr)
    return failure();

  auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
                                             alignment);
  return success();
}

LogicalResult
ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());

  auto memorySpaceAttr =
      dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
  if (!memorySpaceAttr)
    return rewriter.notifyMatchFailure(
        loadOp, "missing memory space SPIR-V storage class attribute");

  if (memorySpaceAttr.getValue() != spirv::StorageClass::Image)
    return rewriter.notifyMatchFailure(
        loadOp, "failed to lower memref in non-image storage class to image");

  Value loadPtr = adaptor.getMemref();
  auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  const auto [memoryAccess, alignment] = *memoryRequirements;

  if (!loadOp.getMemRefType().hasRank())
    return rewriter.notifyMatchFailure(
        loadOp, "cannot lower unranked memrefs to SPIR-V images");

  // We currently only support lowering of scalar memref elements to texels in
  // the R[16|32][f|i|ui] formats. Future work will enable lowering of vector
  // elements to texels in richer formats.
  if (!isa<spirv::ScalarType>(loadOp.getMemRefType().getElementType()))
    return rewriter.notifyMatchFailure(
        loadOp,
        "cannot lower memrefs whose element type is not a SPIR-V scalar type "
        "to SPIR-V images");

  // We currently only support sampled images: OpImageFetch does not work for
  // plain images, so the OpImageRead instruction would need to be materialized
  // instead, or texels would need to be accessed via atomics through a texel
  // pointer. Future work will generalize support to plain images.
  auto convertedPointeeType = cast<spirv::PointerType>(
      getTypeConverter()->convertType(loadOp.getMemRefType()));
  if (!isa<spirv::SampledImageType>(convertedPointeeType.getPointeeType()))
    return rewriter.notifyMatchFailure(loadOp,
                                       "cannot lower memrefs which do not "
                                       "convert to SPIR-V sampled images");

  // Materialize the lowering.
  Location loc = loadOp->getLoc();
  auto imageLoadOp =
      spirv::LoadOp::create(rewriter, loc, loadPtr, memoryAccess, alignment);
  // Extract the image from the sampled image.
  auto imageOp = spirv::ImageOp::create(rewriter, loc, imageLoadOp);

  // Build a vector of coordinates or just a scalar index if we have a 1-D
  // image.
  Value coords;
  if (memrefType.getRank() != 1) {
    auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()},
                                           adaptor.getIndices().getType()[0]);
    coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType,
                                                 adaptor.getIndices());
  } else {
    coords = adaptor.getIndices()[0];
  }

  // Fetch the value out of the image.
  auto resultVectorType = VectorType::get({4}, loadOp.getType());
  auto fetchOp = spirv::ImageFetchOp::create(
      rewriter, loc, resultVectorType, imageOp, coords,
      mlir::spirv::ImageOperandsAttr{}, ValueRange{});

  // Note that because OpImageFetch returns a rank-4 vector, we need to extract
  // the element corresponding to the load; since we only support the
  // R[16|32][f|i|ui] formats, this is always the R (red) 0th vector element.
  auto compositeExtractOp =
      spirv::CompositeExtractOp::create(rewriter, loc, fetchOp, 0);

  rewriter.replaceOp(loadOp, compositeExtractOp);
  return success();
}

LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp,
                                       "element type is not a signless int");

  auto loc = storeOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to convert element pointer type");

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();

  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(storeOp,
                                       "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  IntegerType dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(pointeeType);
  } else {
    // For Vulkan we need to extract element from wrapping struct and array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(
          cast<spirv::RuntimeArrayType>(structElemType).getElementType());
  }

  if (!dstType)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine destination element type");

  int dstBits = static_cast<int>(dstType.getWidth());
  assert(dstBits % srcBits == 0);

  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          storeOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value storeVal = adaptor.getValue();
    if (isBool)
      storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
                                                memoryAccess, alignment);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Since multiple threads may write different elements within the same
  // destination word concurrently, the emulation is done with atomic
  // operations. E.g., if the stored value is i8, rewrite the StoreOp to:
  // 1) load a 32-bit integer
  // 2) clear 8 bits in the loaded value
  // 3) set 8 bits in the loaded value
  // 4) store the 32-bit value back
  //
  // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
  // loaded 32-bit value and the shifted 8-bit store value) as another atomic
  // step.
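  //
  // As an illustrative sketch (abbreviated, not verbatim output), storing an
  // i8 value %v at byte offset 1 of an i32 word in a storage buffer becomes
  // roughly:
  //   %clear = spirv.AtomicAnd <Device> <AcquireRelease> %ptr, %c0xFFFF00FF
  //   %set   = spirv.AtomicOr  <Device> <AcquireRelease> %ptr, %shifted_v
  // where %shifted_v is (%v & 0xFF) << 8.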
  assert(accessChainOp.getIndices().size() == 2);
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);

  // Create a mask to clear the destination. E.g., if it is the second i8 in
  // i32, 0xFFFF00FF is created.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
      loc, dstType, mask, offset);
  clearBitsMask =
      rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);

  Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");

  Value result = spirv::AtomicAndOp::create(
      rewriter, loc, dstType, adjustedPtr, *scope,
      spirv::MemorySemantics::AcquireRelease, clearBitsMask);
  result = spirv::AtomicOrOp::create(
      rewriter, loc, dstType, adjustedPtr, *scope,
      spirv::MemorySemantics::AcquireRelease, storeVal);

  // The AtomicOrOp above already performs the store and its result is unused,
  // so we can simply remove the original StoreOp. Note that
  // rewriter.replaceOp() doesn't work here because it requires the numbers of
  // results to match, and memref.store has no results.
  rewriter.eraseOp(storeOp);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

//===----------------------------------------------------------------------===//
// MemorySpaceCastOp
//===----------------------------------------------------------------------===//

LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
    memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = addrCastOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  if (!typeConverter.allows(spirv::Capability::Kernel))
    return rewriter.notifyMatchFailure(
        loc, "address space casts require kernel capability");

  auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
  if (!sourceType)
    return rewriter.notifyMatchFailure(
        loc, "SPIR-V lowering requires ranked memref types");
  auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());

  auto sourceStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
  if (!sourceStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
      diag << "source address space " << sourceType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });
  auto resultStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
  if (!resultStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
      diag << "result address space " << resultType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });

  spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
  spirv::StorageClass resultSc = resultStorageClassAttr.getValue();

  Value result = adaptor.getSource();
  Type resultPtrType = typeConverter.convertType(resultType);
  if (!resultPtrType)
    return rewriter.notifyMatchFailure(addrCastOp,
                                       "failed to convert memref type");

  Type genericPtrType = resultPtrType;
  // SPIR-V doesn't have a general address space cast operation. Instead, it has
  // conversions to and from generic pointers. To implement the general case,
  // we use specific-to-generic conversions when the source class is not
  // generic. Then, when the result storage class is not generic, we convert
  // the generic pointer (either the input or an intermediate result) to that
  // class. This also means that we'll need the intermediate generic pointer
  // type if neither the source nor the destination is generic.
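  //
  // For example (a sketch): casting from CrossWorkgroup to Function storage
  // class goes through both conversions below, i.e.
  //   %g = spirv.PtrCastToGeneric %src : ... to !spirv.ptr<..., Generic>
  //   %r = spirv.GenericCastToPtr %g   : ... to !spirv.ptr<..., Function>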
  if (sourceSc != spirv::StorageClass::Generic &&
      resultSc != spirv::StorageClass::Generic) {
    Type intermediateType =
        MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
                        sourceType.getLayout(),
                        rewriter.getAttr<spirv::StorageClassAttr>(
                            spirv::StorageClass::Generic));
    genericPtrType = typeConverter.convertType(intermediateType);
  }
  if (sourceSc != spirv::StorageClass::Generic) {
    result = spirv::PtrCastToGenericOp::create(rewriter, loc, genericPtrType,
                                               result);
  }
  if (resultSc != spirv::StorageClass::Generic) {
    result =
        spirv::GenericCastToPtrOp::create(rewriter, loc, resultPtrType, result);
  }
  rewriter.replaceOp(addrCastOp, result);
  return success();
}

LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp, "signless int");
  auto storePtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), storeOp.getLoc(), rewriter);

  if (!storePtr)
    return rewriter.notifyMatchFailure(storeOp, "type conversion failed");

  auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::StoreOp>(
      storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
  return success();
}

LogicalResult ReinterpretCastPattern::matchAndRewrite(
    memref::ReinterpretCastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value src = adaptor.getSource();
  auto srcType = dyn_cast<spirv::PointerType>(src.getType());

  if (!srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid src type " << src.getType();
    });

  const TypeConverter *converter = getTypeConverter();

  auto dstType = converter->convertType<spirv::PointerType>(op.getType());
  if (dstType != srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid dst type " << op.getType();
    });

  OpFoldResult offset =
      getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
          .front();
  if (isZeroInteger(offset)) {
    rewriter.replaceOp(op, src);
    return success();
  }

  Type intType = converter->convertType(rewriter.getIndexType());
  if (!intType)
    return rewriter.notifyMatchFailure(op, "failed to convert index type");

  Location loc = op.getLoc();
  auto offsetValue = [&]() -> Value {
    if (auto val = dyn_cast<Value>(offset))
      return val;

    int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
    Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
    return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
  }();

  rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
      op, src, offsetValue, ValueRange());
  return success();
}

//===----------------------------------------------------------------------===//
// ExtractAlignedPointerAsIndexOp
//===----------------------------------------------------------------------===//

LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
    memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Type indexType = typeConverter.getIndexType();
  rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
                                                      adaptor.getSource());
  return success();
}

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

namespace mlir {
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                   RewritePatternSet &patterns) {
  patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
               DeallocOpPattern, IntLoadOpPattern, ImageLoadOpPattern,
               IntStoreOpPattern, LoadOpPattern, MemorySpaceCastOpPattern,
               StoreOpPattern, ReinterpretCastPattern, CastPattern,
               ExtractAlignedPointerAsIndexOpPattern>(typeConverter,
                                                      patterns.getContext());
}
} // namespace mlir
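
// Illustrative usage (a minimal sketch; the surrounding pass setup is an
// assumption, not part of this file): these patterns are typically driven by
// a dialect-conversion pass that derives the type converter and conversion
// target from the spirv.target_env attribute, e.g.:
//
//   RewritePatternSet patterns(context);
//   populateMemRefToSPIRVPatterns(typeConverter, patterns);
//   if (failed(applyPartialConversion(module, *target, std::move(patterns))))
//     return signalPassFailure();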