MLIR  16.0.0git
MemRefToSPIRV.cpp
Go to the documentation of this file.
1 //===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert MemRef dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "llvm/Support/Debug.h"
19 
20 #define DEBUG_TYPE "memref-to-spirv-pattern"
21 
22 using namespace mlir;
23 
24 //===----------------------------------------------------------------------===//
25 // Utility functions
26 //===----------------------------------------------------------------------===//
27 
28 /// Returns the offset of the value in `targetBits` representation.
29 ///
30 /// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
31 /// It's assumed to be non-negative.
32 ///
33 /// When accessing an element in the array treating as having elements of
34 /// `targetBits`, multiple values are loaded in the same time. The method
35 /// returns the offset where the `srcIdx` locates in the value. For example, if
36 /// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is
37 /// located at (x % 4) * 8. Because there are four elements in one i32, and one
38 /// element has 8 bits.
39 static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
40  int targetBits, OpBuilder &builder) {
41  assert(targetBits % sourceBits == 0);
42  IntegerType targetType = builder.getIntegerType(targetBits);
43  IntegerAttr idxAttr =
44  builder.getIntegerAttr(targetType, targetBits / sourceBits);
45  auto idx = builder.create<spirv::ConstantOp>(loc, targetType, idxAttr);
46  IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
47  auto srcBitsValue =
48  builder.create<spirv::ConstantOp>(loc, targetType, srcBitsAttr);
49  auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
50  return builder.create<spirv::IMulOp>(loc, targetType, m, srcBitsValue);
51 }
52 
53 /// Returns an adjusted spirv::AccessChainOp. Based on the
54 /// extension/capabilities, certain integer bitwidths `sourceBits` might not be
55 /// supported. During conversion if a memref of an unsupported type is used,
56 /// load/stores to this memref need to be modified to use a supported higher
57 /// bitwidth `targetBits` and extracting the required bits. For an accessing a
58 /// 1D array (spirv.array or spirv.rt_array), the last index is modified to load
59 /// the bits needed. The extraction of the actual bits needed are handled
60 /// separately. Note that this only works for a 1-D tensor.
62  spirv::AccessChainOp op,
63  int sourceBits, int targetBits,
64  OpBuilder &builder) {
65  assert(targetBits % sourceBits == 0);
66  const auto loc = op.getLoc();
67  IntegerType targetType = builder.getIntegerType(targetBits);
68  IntegerAttr attr =
69  builder.getIntegerAttr(targetType, targetBits / sourceBits);
70  auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
71  auto lastDim = op->getOperand(op.getNumOperands() - 1);
72  auto indices = llvm::to_vector<4>(op.getIndices());
73  // There are two elements if this is a 1-D tensor.
74  assert(indices.size() == 2);
75  indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
76  Type t = typeConverter.convertType(op.getComponentPtr().getType());
77  return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
78 }
79 
80 /// Returns the shifted `targetBits`-bit value with the given offset.
81 static Value shiftValue(Location loc, Value value, Value offset, Value mask,
82  int targetBits, OpBuilder &builder) {
83  Type targetType = builder.getIntegerType(targetBits);
84  Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
85  return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result,
86  offset);
87 }
88 
89 /// Returns true if the allocations of memref `type` generated from `allocOp`
90 /// can be lowered to SPIR-V.
91 static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
92  if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
93  auto sc = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
94  if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
95  return false;
96  } else if (isa<memref::AllocaOp>(allocOp)) {
97  auto sc = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
98  if (!sc || sc.getValue() != spirv::StorageClass::Function)
99  return false;
100  } else {
101  return false;
102  }
103 
104  // Currently only support static shape and int or float or vector of int or
105  // float element type.
106  if (!type.hasStaticShape())
107  return false;
108 
109  Type elementType = type.getElementType();
110  if (auto vecType = elementType.dyn_cast<VectorType>())
111  elementType = vecType.getElementType();
112  return elementType.isIntOrFloat();
113 }
114 
115 /// Returns the scope to use for atomic operations use for emulating store
116 /// operations of unsupported integer bitwidths, based on the memref
117 /// type. Returns None on failure.
118 static Optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
119  auto sc = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
120  switch (sc.getValue()) {
121  case spirv::StorageClass::StorageBuffer:
122  return spirv::Scope::Device;
123  case spirv::StorageClass::Workgroup:
124  return spirv::Scope::Workgroup;
125  default:
126  break;
127  }
128  return {};
129 }
130 
131 /// Casts the given `srcInt` into a boolean value.
132 static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
133  if (srcInt.getType().isInteger(1))
134  return srcInt;
135 
136  auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
137  return builder.create<spirv::IEqualOp>(loc, srcInt, one);
138 }
139 
140 /// Casts the given `srcBool` into an integer of `dstType`.
141 static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
142  OpBuilder &builder) {
143  assert(srcBool.getType().isInteger(1));
144  if (dstType.isInteger(1))
145  return srcBool;
146  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
147  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
148  return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
149 }
150 
151 //===----------------------------------------------------------------------===//
152 // Operation conversion
153 //===----------------------------------------------------------------------===//
154 
155 // Note that DRR cannot be used for the patterns in this file: we may need to
156 // convert type along the way, which requires ConversionPattern. DRR generates
157 // normal RewritePattern.
158 
159 namespace {
160 
161 /// Converts memref.alloca to SPIR-V Function variables.
162 class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
163 public:
165 
167  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
168  ConversionPatternRewriter &rewriter) const override;
169 };
170 
171 /// Converts an allocation operation to SPIR-V. Currently only supports lowering
172 /// to Workgroup memory when the size is constant. Note that this pattern needs
173 /// to be applied in a pass that runs at least at spirv.module scope since it
174 /// wil ladd global variables into the spirv.module.
175 class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
176 public:
178 
180  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
181  ConversionPatternRewriter &rewriter) const override;
182 };
183 
184 /// Removed a deallocation if it is a supported allocation. Currently only
185 /// removes deallocation if the memory space is workgroup memory.
186 class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
187 public:
189 
191  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
192  ConversionPatternRewriter &rewriter) const override;
193 };
194 
195 /// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
196 class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
197 public:
199 
201  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
202  ConversionPatternRewriter &rewriter) const override;
203 };
204 
205 /// Converts memref.load to spirv.Load + spirv.AccessChain.
206 class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
207 public:
209 
211  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
212  ConversionPatternRewriter &rewriter) const override;
213 };
214 
215 /// Converts memref.store to spirv.Store on integers.
216 class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
217 public:
219 
221  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
222  ConversionPatternRewriter &rewriter) const override;
223 };
224 
225 /// Converts memref.store to spirv.Store.
226 class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
227 public:
229 
231  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
232  ConversionPatternRewriter &rewriter) const override;
233 };
234 
235 } // namespace
236 
237 //===----------------------------------------------------------------------===//
238 // AllocaOp
239 //===----------------------------------------------------------------------===//
240 
242 AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
243  ConversionPatternRewriter &rewriter) const {
244  MemRefType allocType = allocaOp.getType();
245  if (!isAllocationSupported(allocaOp, allocType))
246  return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");
247 
248  // Get the SPIR-V type for the allocation.
249  Type spirvType = getTypeConverter()->convertType(allocType);
250  rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
251  spirv::StorageClass::Function,
252  /*initializer=*/nullptr);
253  return success();
254 }
255 
256 //===----------------------------------------------------------------------===//
257 // AllocOp
258 //===----------------------------------------------------------------------===//
259 
261 AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
262  ConversionPatternRewriter &rewriter) const {
263  MemRefType allocType = operation.getType();
264  if (!isAllocationSupported(operation, allocType))
265  return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
266 
267  // Get the SPIR-V type for the allocation.
268  Type spirvType = getTypeConverter()->convertType(allocType);
269 
270  // Insert spirv.GlobalVariable for this allocation.
271  Operation *parent =
272  SymbolTable::getNearestSymbolTable(operation->getParentOp());
273  if (!parent)
274  return failure();
275  Location loc = operation.getLoc();
276  spirv::GlobalVariableOp varOp;
277  {
278  OpBuilder::InsertionGuard guard(rewriter);
279  Block &entryBlock = *parent->getRegion(0).begin();
280  rewriter.setInsertionPointToStart(&entryBlock);
281  auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
282  std::string varName =
283  std::string("__workgroup_mem__") +
284  std::to_string(std::distance(varOps.begin(), varOps.end()));
285  varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
286  /*initializer=*/nullptr);
287  }
288 
289  // Get pointer to global variable at the current scope.
290  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
291  return success();
292 }
293 
294 //===----------------------------------------------------------------------===//
295 // DeallocOp
296 //===----------------------------------------------------------------------===//
297 
299 DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
300  OpAdaptor adaptor,
301  ConversionPatternRewriter &rewriter) const {
302  MemRefType deallocType = operation.getMemref().getType().cast<MemRefType>();
303  if (!isAllocationSupported(operation, deallocType))
304  return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
305  rewriter.eraseOp(operation);
306  return success();
307 }
308 
309 //===----------------------------------------------------------------------===//
310 // LoadOp
311 //===----------------------------------------------------------------------===//
312 
314 IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
315  ConversionPatternRewriter &rewriter) const {
316  auto loc = loadOp.getLoc();
317  auto memrefType = loadOp.getMemref().getType().cast<MemRefType>();
318  if (!memrefType.getElementType().isSignlessInteger())
319  return failure();
320 
321  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
322  Value accessChain =
323  spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
324  adaptor.getIndices(), loc, rewriter);
325 
326  if (!accessChain)
327  return failure();
328 
329  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
330  bool isBool = srcBits == 1;
331  if (isBool)
332  srcBits = typeConverter.getOptions().boolNumBits;
333  Type pointeeType = typeConverter.convertType(memrefType)
335  .getPointeeType();
336  Type dstType;
337  if (typeConverter.allows(spirv::Capability::Kernel)) {
338  if (auto arrayType = pointeeType.dyn_cast<spirv::ArrayType>())
339  dstType = arrayType.getElementType();
340  else
341  dstType = pointeeType;
342  } else {
343  // For Vulkan we need to extract element from wrapping struct and array.
344  Type structElemType =
345  pointeeType.cast<spirv::StructType>().getElementType(0);
346  if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
347  dstType = arrayType.getElementType();
348  else
349  dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
350  }
351  int dstBits = dstType.getIntOrFloatBitWidth();
352  assert(dstBits % srcBits == 0);
353 
354  // If the rewrited load op has the same bit width, use the loading value
355  // directly.
356  if (srcBits == dstBits) {
357  Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain);
358  if (isBool)
359  loadVal = castIntNToBool(loc, loadVal, rewriter);
360  rewriter.replaceOp(loadOp, loadVal);
361  return success();
362  }
363 
364  // Bitcasting is currently unsupported for Kernel capability /
365  // spirv.PtrAccessChain.
366  if (typeConverter.allows(spirv::Capability::Kernel))
367  return failure();
368 
369  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
370  if (!accessChainOp)
371  return failure();
372 
373  // Assume that getElementPtr() works linearizely. If it's a scalar, the method
374  // still returns a linearized accessing. If the accessing is not linearized,
375  // there will be offset issues.
376  assert(accessChainOp.getIndices().size() == 2);
377  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
378  srcBits, dstBits, rewriter);
379  Value spvLoadOp = rewriter.create<spirv::LoadOp>(
380  loc, dstType, adjustedPtr,
381  loadOp->getAttrOfType<spirv::MemoryAccessAttr>(
382  spirv::attributeName<spirv::MemoryAccess>()),
383  loadOp->getAttrOfType<IntegerAttr>("alignment"));
384 
385  // Shift the bits to the rightmost.
386  // ____XXXX________ -> ____________XXXX
387  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
388  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
389  Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
390  loc, spvLoadOp.getType(), spvLoadOp, offset);
391 
392  // Apply the mask to extract corresponding bits.
393  Value mask = rewriter.create<spirv::ConstantOp>(
394  loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
395  result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
396 
397  // Apply sign extension on the loading value unconditionally. The signedness
398  // semantic is carried in the operator itself, we relies other pattern to
399  // handle the casting.
400  IntegerAttr shiftValueAttr =
401  rewriter.getIntegerAttr(dstType, dstBits - srcBits);
402  Value shiftValue =
403  rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
404  result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
405  shiftValue);
406  result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
407  shiftValue);
408 
409  if (isBool) {
410  dstType = typeConverter.convertType(loadOp.getType());
411  mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter);
412  result = rewriter.create<spirv::IEqualOp>(loc, result, mask);
413  } else if (result.getType().getIntOrFloatBitWidth() !=
414  static_cast<unsigned>(dstBits)) {
415  result = rewriter.create<spirv::SConvertOp>(loc, dstType, result);
416  }
417  rewriter.replaceOp(loadOp, result);
418 
419  assert(accessChainOp.use_empty());
420  rewriter.eraseOp(accessChainOp);
421 
422  return success();
423 }
424 
426 LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
427  ConversionPatternRewriter &rewriter) const {
428  auto memrefType = loadOp.getMemref().getType().cast<MemRefType>();
429  if (memrefType.getElementType().isSignlessInteger())
430  return failure();
431  auto loadPtr = spirv::getElementPtr(
432  *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
433  adaptor.getIndices(), loadOp.getLoc(), rewriter);
434 
435  if (!loadPtr)
436  return failure();
437 
438  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
439  return success();
440 }
441 
443 IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
444  ConversionPatternRewriter &rewriter) const {
445  auto memrefType = storeOp.getMemref().getType().cast<MemRefType>();
446  if (!memrefType.getElementType().isSignlessInteger())
447  return failure();
448 
449  auto loc = storeOp.getLoc();
450  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
451  Value accessChain =
452  spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
453  adaptor.getIndices(), loc, rewriter);
454 
455  if (!accessChain)
456  return failure();
457 
458  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
459 
460  bool isBool = srcBits == 1;
461  if (isBool)
462  srcBits = typeConverter.getOptions().boolNumBits;
463 
464  Type pointeeType = typeConverter.convertType(memrefType)
466  .getPointeeType();
467  Type dstType;
468  if (typeConverter.allows(spirv::Capability::Kernel)) {
469  if (auto arrayType = pointeeType.dyn_cast<spirv::ArrayType>())
470  dstType = arrayType.getElementType();
471  else
472  dstType = pointeeType;
473  } else {
474  // For Vulkan we need to extract element from wrapping struct and array.
475  Type structElemType =
476  pointeeType.cast<spirv::StructType>().getElementType(0);
477  if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
478  dstType = arrayType.getElementType();
479  else
480  dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
481  }
482 
483  int dstBits = dstType.getIntOrFloatBitWidth();
484  assert(dstBits % srcBits == 0);
485 
486  if (srcBits == dstBits) {
487  Value storeVal = adaptor.getValue();
488  if (isBool)
489  storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
490  rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal);
491  return success();
492  }
493 
494  // Bitcasting is currently unsupported for Kernel capability /
495  // spirv.PtrAccessChain.
496  if (typeConverter.allows(spirv::Capability::Kernel))
497  return failure();
498 
499  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
500  if (!accessChainOp)
501  return failure();
502 
503  // Since there are multi threads in the processing, the emulation will be done
504  // with atomic operations. E.g., if the storing value is i8, rewrite the
505  // StoreOp to
506  // 1) load a 32-bit integer
507  // 2) clear 8 bits in the loading value
508  // 3) store 32-bit value back
509  // 4) load a 32-bit integer
510  // 5) modify 8 bits in the loading value
511  // 6) store 32-bit value back
512  // The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step
513  // 4 to step 6 are done by AtomicOr as another atomic step.
514  assert(accessChainOp.getIndices().size() == 2);
515  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
516  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
517 
518  // Create a mask to clear the destination. E.g., if it is the second i8 in
519  // i32, 0xFFFF00FF is created.
520  Value mask = rewriter.create<spirv::ConstantOp>(
521  loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
522  Value clearBitsMask =
523  rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
524  clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
525 
526  Value storeVal = adaptor.getValue();
527  if (isBool)
528  storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
529  storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter);
530  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
531  srcBits, dstBits, rewriter);
532  Optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
533  if (!scope)
534  return failure();
535  Value result = rewriter.create<spirv::AtomicAndOp>(
536  loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
537  clearBitsMask);
538  result = rewriter.create<spirv::AtomicOrOp>(
539  loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
540  storeVal);
541 
542  // The AtomicOrOp has no side effect. Since it is already inserted, we can
543  // just remove the original StoreOp. Note that rewriter.replaceOp()
544  // doesn't work because it only accepts that the numbers of result are the
545  // same.
546  rewriter.eraseOp(storeOp);
547 
548  assert(accessChainOp.use_empty());
549  rewriter.eraseOp(accessChainOp);
550 
551  return success();
552 }
553 
555 StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
556  ConversionPatternRewriter &rewriter) const {
557  auto memrefType = storeOp.getMemref().getType().cast<MemRefType>();
558  if (memrefType.getElementType().isSignlessInteger())
559  return failure();
560  auto storePtr = spirv::getElementPtr(
561  *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
562  adaptor.getIndices(), storeOp.getLoc(), rewriter);
563 
564  if (!storePtr)
565  return failure();
566 
567  rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
568  adaptor.getValue());
569  return success();
570 }
571 
572 //===----------------------------------------------------------------------===//
573 // Pattern population
574 //===----------------------------------------------------------------------===//
575 
576 namespace mlir {
578  RewritePatternSet &patterns) {
579  patterns
580  .add<AllocaOpPattern, AllocOpPattern, DeallocOpPattern, IntLoadOpPattern,
581  IntStoreOpPattern, LoadOpPattern, StoreOpPattern>(
582  typeConverter, patterns.getContext());
583 }
584 } // namespace mlir
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static constexpr const bool value
static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)
Returns an adjusted spirv::AccessChainOp.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder)
Casts the given srcInt into a boolean value.
static bool isAllocationSupported(Operation *allocOp, MemRefType type)
Returns true if the allocations of memref type generated from allocOp can be lowered to SPIR-V.
static Optional< spirv::Scope > getAtomicOpScope(MemRefType type)
Returns the scope to use for atomic operations use for emulating store operations of unsupported inte...
static Value shiftValue(Location loc, Value value, Value offset, Value mask, int targetBits, OpBuilder &builder)
Returns the shifted targetBits-bit value with the given offset.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
Returns the offset of the value in targetBits representation.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder)
Casts the given srcBool into an integer of dstType.
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:696
Block represents an ordered list of Operations.
Definition: Block.h:30
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition: Block.h:182
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:212
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:72
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:64
U cast() const
Definition: Location.h:90
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:300
This class helps build Operations.
Definition: Builders.h:198
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:383
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
Value getOperand(unsigned idx)
Definition: Operation.h:267
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:486
iterator begin()
Definition: Region.h:55
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
Type conversion from builtin types to SPIR-V types for shader interface.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:280
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
Definition: Types.cpp:33
U dyn_cast() const
Definition: Types.h:270
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:89
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:93
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Type getType() const
Return the type of this value.
Definition: Value.h:114
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
SPIR-V struct type.
Definition: SPIRVTypes.h:281
Value getElementPtr(SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26