MLIR  14.0.0git
MemRefToSPIRV.cpp
Go to the documentation of this file.
1 //===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert MemRef dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 #include "llvm/Support/Debug.h"
18 
19 #define DEBUG_TYPE "memref-to-spirv-pattern"
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // Utility functions
25 //===----------------------------------------------------------------------===//
26 
27 /// Returns the offset of the value in `targetBits` representation.
28 ///
29 /// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
30 /// It's assumed to be non-negative.
31 ///
32 /// When accessing an element in the array treating as having elements of
33 /// `targetBits`, multiple values are loaded in the same time. The method
34 /// returns the offset where the `srcIdx` locates in the value. For example, if
35 /// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is
36 /// located at (x % 4) * 8. Because there are four elements in one i32, and one
37 /// element has 8 bits.
38 static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
39  int targetBits, OpBuilder &builder) {
40  assert(targetBits % sourceBits == 0);
41  IntegerType targetType = builder.getIntegerType(targetBits);
42  IntegerAttr idxAttr =
43  builder.getIntegerAttr(targetType, targetBits / sourceBits);
44  auto idx = builder.create<spirv::ConstantOp>(loc, targetType, idxAttr);
45  IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
46  auto srcBitsValue =
47  builder.create<spirv::ConstantOp>(loc, targetType, srcBitsAttr);
48  auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
49  return builder.create<spirv::IMulOp>(loc, targetType, m, srcBitsValue);
50 }
51 
52 /// Returns an adjusted spirv::AccessChainOp. Based on the
53 /// extension/capabilities, certain integer bitwidths `sourceBits` might not be
54 /// supported. During conversion if a memref of an unsupported type is used,
55 /// load/stores to this memref need to be modified to use a supported higher
56 /// bitwidth `targetBits` and extracting the required bits. For an accessing a
57 /// 1D array (spv.array or spv.rt_array), the last index is modified to load the
58 /// bits needed. The extraction of the actual bits needed are handled
59 /// separately. Note that this only works for a 1-D tensor.
61  spirv::AccessChainOp op,
62  int sourceBits, int targetBits,
63  OpBuilder &builder) {
64  assert(targetBits % sourceBits == 0);
65  const auto loc = op.getLoc();
66  IntegerType targetType = builder.getIntegerType(targetBits);
67  IntegerAttr attr =
68  builder.getIntegerAttr(targetType, targetBits / sourceBits);
69  auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
70  auto lastDim = op->getOperand(op.getNumOperands() - 1);
71  auto indices = llvm::to_vector<4>(op.indices());
72  // There are two elements if this is a 1-D tensor.
73  assert(indices.size() == 2);
74  indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
75  Type t = typeConverter.convertType(op.component_ptr().getType());
76  return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
77 }
78 
79 /// Returns the shifted `targetBits`-bit value with the given offset.
80 static Value shiftValue(Location loc, Value value, Value offset, Value mask,
81  int targetBits, OpBuilder &builder) {
82  Type targetType = builder.getIntegerType(targetBits);
83  Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
84  return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result,
85  offset);
86 }
87 
88 /// Returns true if the allocations of type `t` can be lowered to SPIR-V.
89 static bool isAllocationSupported(MemRefType t) {
90  // Currently only support workgroup local memory allocations with static
91  // shape and int or float or vector of int or float element type.
92  if (!(t.hasStaticShape() &&
94  spirv::StorageClass::Workgroup) == t.getMemorySpaceAsInt()))
95  return false;
96  Type elementType = t.getElementType();
97  if (auto vecType = elementType.dyn_cast<VectorType>())
98  elementType = vecType.getElementType();
99  return elementType.isIntOrFloat();
100 }
101 
102 /// Returns the scope to use for atomic operations use for emulating store
103 /// operations of unsupported integer bitwidths, based on the memref
104 /// type. Returns None on failure.
106  Optional<spirv::StorageClass> storageClass =
108  t.getMemorySpaceAsInt());
109  if (!storageClass)
110  return {};
111  switch (*storageClass) {
112  case spirv::StorageClass::StorageBuffer:
113  return spirv::Scope::Device;
114  case spirv::StorageClass::Workgroup:
115  return spirv::Scope::Workgroup;
116  default: {
117  }
118  }
119  return {};
120 }
121 
122 /// Casts the given `srcInt` into a boolean value.
123 static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
124  if (srcInt.getType().isInteger(1))
125  return srcInt;
126 
127  auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
128  return builder.create<spirv::IEqualOp>(loc, srcInt, one);
129 }
130 
131 /// Casts the given `srcBool` into an integer of `dstType`.
132 static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
133  OpBuilder &builder) {
134  assert(srcBool.getType().isInteger(1));
135  if (dstType.isInteger(1))
136  return srcBool;
137  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
138  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
139  return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
140 }
141 
142 //===----------------------------------------------------------------------===//
143 // Operation conversion
144 //===----------------------------------------------------------------------===//
145 
146 // Note that DRR cannot be used for the patterns in this file: we may need to
147 // convert type along the way, which requires ConversionPattern. DRR generates
148 // normal RewritePattern.
149 
150 namespace {
151 
152 /// Converts an allocation operation to SPIR-V. Currently only supports lowering
153 /// to Workgroup memory when the size is constant. Note that this pattern needs
154 /// to be applied in a pass that runs at least at spv.module scope since it wil
155 /// ladd global variables into the spv.module.
156 class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
157 public:
159 
161  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
162  ConversionPatternRewriter &rewriter) const override;
163 };
164 
165 /// Removed a deallocation if it is a supported allocation. Currently only
166 /// removes deallocation if the memory space is workgroup memory.
167 class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
168 public:
170 
172  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
173  ConversionPatternRewriter &rewriter) const override;
174 };
175 
176 /// Converts memref.load to spv.Load.
177 class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
178 public:
180 
182  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
183  ConversionPatternRewriter &rewriter) const override;
184 };
185 
186 /// Converts memref.load to spv.Load.
187 class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
188 public:
190 
192  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
193  ConversionPatternRewriter &rewriter) const override;
194 };
195 
196 /// Converts memref.store to spv.Store on integers.
197 class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
198 public:
200 
202  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
203  ConversionPatternRewriter &rewriter) const override;
204 };
205 
206 /// Converts memref.store to spv.Store.
207 class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
208 public:
210 
212  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
213  ConversionPatternRewriter &rewriter) const override;
214 };
215 
216 } // namespace
217 
218 //===----------------------------------------------------------------------===//
219 // AllocOp
220 //===----------------------------------------------------------------------===//
221 
223 AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
224  ConversionPatternRewriter &rewriter) const {
225  MemRefType allocType = operation.getType();
226  if (!isAllocationSupported(allocType))
227  return operation.emitError("unhandled allocation type");
228 
229  // Get the SPIR-V type for the allocation.
230  Type spirvType = getTypeConverter()->convertType(allocType);
231 
232  // Insert spv.GlobalVariable for this allocation.
233  Operation *parent =
234  SymbolTable::getNearestSymbolTable(operation->getParentOp());
235  if (!parent)
236  return failure();
237  Location loc = operation.getLoc();
238  spirv::GlobalVariableOp varOp;
239  {
240  OpBuilder::InsertionGuard guard(rewriter);
241  Block &entryBlock = *parent->getRegion(0).begin();
242  rewriter.setInsertionPointToStart(&entryBlock);
243  auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
244  std::string varName =
245  std::string("__workgroup_mem__") +
246  std::to_string(std::distance(varOps.begin(), varOps.end()));
247  varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
248  /*initializer=*/nullptr);
249  }
250 
251  // Get pointer to global variable at the current scope.
252  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
253  return success();
254 }
255 
256 //===----------------------------------------------------------------------===//
257 // DeallocOp
258 //===----------------------------------------------------------------------===//
259 
261 DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
262  OpAdaptor adaptor,
263  ConversionPatternRewriter &rewriter) const {
264  MemRefType deallocType = operation.memref().getType().cast<MemRefType>();
265  if (!isAllocationSupported(deallocType))
266  return operation.emitError("unhandled deallocation type");
267  rewriter.eraseOp(operation);
268  return success();
269 }
270 
271 //===----------------------------------------------------------------------===//
272 // LoadOp
273 //===----------------------------------------------------------------------===//
274 
276 IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
277  ConversionPatternRewriter &rewriter) const {
278  auto loc = loadOp.getLoc();
279  auto memrefType = loadOp.memref().getType().cast<MemRefType>();
280  if (!memrefType.getElementType().isSignlessInteger())
281  return failure();
282 
283  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
284  spirv::AccessChainOp accessChainOp =
285  spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(),
286  adaptor.indices(), loc, rewriter);
287 
288  if (!accessChainOp)
289  return failure();
290 
291  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
292  bool isBool = srcBits == 1;
293  if (isBool)
294  srcBits = typeConverter.getOptions().boolNumBits;
295  Type pointeeType = typeConverter.convertType(memrefType)
297  .getPointeeType();
298  Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
299  Type dstType;
300  if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
301  dstType = arrayType.getElementType();
302  else
303  dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
304 
305  int dstBits = dstType.getIntOrFloatBitWidth();
306  assert(dstBits % srcBits == 0);
307 
308  // If the rewrited load op has the same bit width, use the loading value
309  // directly.
310  if (srcBits == dstBits) {
311  Value loadVal =
312  rewriter.create<spirv::LoadOp>(loc, accessChainOp.getResult());
313  if (isBool)
314  loadVal = castIntNToBool(loc, loadVal, rewriter);
315  rewriter.replaceOp(loadOp, loadVal);
316  return success();
317  }
318 
319  // Assume that getElementPtr() works linearizely. If it's a scalar, the method
320  // still returns a linearized accessing. If the accessing is not linearized,
321  // there will be offset issues.
322  assert(accessChainOp.indices().size() == 2);
323  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
324  srcBits, dstBits, rewriter);
325  Value spvLoadOp = rewriter.create<spirv::LoadOp>(
326  loc, dstType, adjustedPtr,
327  loadOp->getAttrOfType<spirv::MemoryAccessAttr>(
328  spirv::attributeName<spirv::MemoryAccess>()),
329  loadOp->getAttrOfType<IntegerAttr>("alignment"));
330 
331  // Shift the bits to the rightmost.
332  // ____XXXX________ -> ____________XXXX
333  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
334  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
335  Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
336  loc, spvLoadOp.getType(), spvLoadOp, offset);
337 
338  // Apply the mask to extract corresponding bits.
339  Value mask = rewriter.create<spirv::ConstantOp>(
340  loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
341  result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
342 
343  // Apply sign extension on the loading value unconditionally. The signedness
344  // semantic is carried in the operator itself, we relies other pattern to
345  // handle the casting.
346  IntegerAttr shiftValueAttr =
347  rewriter.getIntegerAttr(dstType, dstBits - srcBits);
348  Value shiftValue =
349  rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
350  result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
351  shiftValue);
352  result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
353  shiftValue);
354 
355  if (isBool) {
356  dstType = typeConverter.convertType(loadOp.getType());
357  mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter);
358  result = rewriter.create<spirv::IEqualOp>(loc, result, mask);
359  } else if (result.getType().getIntOrFloatBitWidth() !=
360  static_cast<unsigned>(dstBits)) {
361  result = rewriter.create<spirv::SConvertOp>(loc, dstType, result);
362  }
363  rewriter.replaceOp(loadOp, result);
364 
365  assert(accessChainOp.use_empty());
366  rewriter.eraseOp(accessChainOp);
367 
368  return success();
369 }
370 
372 LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
373  ConversionPatternRewriter &rewriter) const {
374  auto memrefType = loadOp.memref().getType().cast<MemRefType>();
375  if (memrefType.getElementType().isSignlessInteger())
376  return failure();
377  auto loadPtr = spirv::getElementPtr(
378  *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.memref(),
379  adaptor.indices(), loadOp.getLoc(), rewriter);
380 
381  if (!loadPtr)
382  return failure();
383 
384  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
385  return success();
386 }
387 
389 IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
390  ConversionPatternRewriter &rewriter) const {
391  auto memrefType = storeOp.memref().getType().cast<MemRefType>();
392  if (!memrefType.getElementType().isSignlessInteger())
393  return failure();
394 
395  auto loc = storeOp.getLoc();
396  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
397  spirv::AccessChainOp accessChainOp =
398  spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(),
399  adaptor.indices(), loc, rewriter);
400 
401  if (!accessChainOp)
402  return failure();
403 
404  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
405 
406  bool isBool = srcBits == 1;
407  if (isBool)
408  srcBits = typeConverter.getOptions().boolNumBits;
409 
410  Type pointeeType = typeConverter.convertType(memrefType)
412  .getPointeeType();
413  Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
414  Type dstType;
415  if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
416  dstType = arrayType.getElementType();
417  else
418  dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
419 
420  int dstBits = dstType.getIntOrFloatBitWidth();
421  assert(dstBits % srcBits == 0);
422 
423  if (srcBits == dstBits) {
424  Value storeVal = adaptor.value();
425  if (isBool)
426  storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
427  rewriter.replaceOpWithNewOp<spirv::StoreOp>(
428  storeOp, accessChainOp.getResult(), storeVal);
429  return success();
430  }
431 
432  // Since there are multi threads in the processing, the emulation will be done
433  // with atomic operations. E.g., if the storing value is i8, rewrite the
434  // StoreOp to
435  // 1) load a 32-bit integer
436  // 2) clear 8 bits in the loading value
437  // 3) store 32-bit value back
438  // 4) load a 32-bit integer
439  // 5) modify 8 bits in the loading value
440  // 6) store 32-bit value back
441  // The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step
442  // 4 to step 6 are done by AtomicOr as another atomic step.
443  assert(accessChainOp.indices().size() == 2);
444  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
445  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
446 
447  // Create a mask to clear the destination. E.g., if it is the second i8 in
448  // i32, 0xFFFF00FF is created.
449  Value mask = rewriter.create<spirv::ConstantOp>(
450  loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
451  Value clearBitsMask =
452  rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
453  clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
454 
455  Value storeVal = adaptor.value();
456  if (isBool)
457  storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
458  storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter);
459  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
460  srcBits, dstBits, rewriter);
461  Optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
462  if (!scope)
463  return failure();
464  Value result = rewriter.create<spirv::AtomicAndOp>(
465  loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
466  clearBitsMask);
467  result = rewriter.create<spirv::AtomicOrOp>(
468  loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
469  storeVal);
470 
471  // The AtomicOrOp has no side effect. Since it is already inserted, we can
472  // just remove the original StoreOp. Note that rewriter.replaceOp()
473  // doesn't work because it only accepts that the numbers of result are the
474  // same.
475  rewriter.eraseOp(storeOp);
476 
477  assert(accessChainOp.use_empty());
478  rewriter.eraseOp(accessChainOp);
479 
480  return success();
481 }
482 
484 StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
485  ConversionPatternRewriter &rewriter) const {
486  auto memrefType = storeOp.memref().getType().cast<MemRefType>();
487  if (memrefType.getElementType().isSignlessInteger())
488  return failure();
489  auto storePtr = spirv::getElementPtr(
490  *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.memref(),
491  adaptor.indices(), storeOp.getLoc(), rewriter);
492 
493  if (!storePtr)
494  return failure();
495 
496  rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
497  adaptor.value());
498  return success();
499 }
500 
501 //===----------------------------------------------------------------------===//
502 // Pattern population
503 //===----------------------------------------------------------------------===//
504 
505 namespace mlir {
507  RewritePatternSet &patterns) {
508  patterns.add<AllocOpPattern, DeallocOpPattern, IntLoadOpPattern,
509  IntStoreOpPattern, LoadOpPattern, StoreOpPattern>(
510  typeConverter, patterns.getContext());
511 }
512 } // namespace mlir
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
Returns the offset of the value in targetBits representation.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
Block represents an ordered list of Operations.
Definition: Block.h:29
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
Definition: Types.cpp:31
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:639
static unsigned getMemorySpaceForStorageClass(spirv::StorageClass)
Returns the corresponding memory space for memref given a SPIR-V storage class.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:87
static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)
Returns an adjusted spirv::AccessChainOp.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
iterator begin()
Definition: Region.h:55
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder)
Casts the given srcBool into an integer of dstType.
static bool isAllocationSupported(MemRefType t)
Returns true if the allocations of type t can be lowered to SPIR-V.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder)
Casts the given srcInt into a boolean value.
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:170
U dyn_cast() const
Definition: Types.h:244
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
static Value shiftValue(Location loc, Value value, Value offset, Value mask, int targetBits, OpBuilder &builder)
Returns the shifted targetBits-bit value with the given offset.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:362
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:279
static Optional< spirv::Scope > getAtomicOpScope(MemRefType t)
Returns the scope to use for atomic operations use for emulating store operations of unsupported inte...
OpTy replaceOpWithNewOp(Operation *op, Args &&... args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:741
Type getType() const
Return the type of this value.
Definition: Value.h:117
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Definition: PatternMatch.h:930
spirv::AccessChainOp getElementPtr(SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr...
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'. ...
Definition: Block.h:184
SPIR-V struct type.
Definition: SPIRVTypes.h:278
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:91
This class implements a pattern rewriter for use with ConversionPatterns.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
static Optional< spirv::StorageClass > getStorageClassForMemorySpace(unsigned space)
Returns the SPIR-V storage class given a memory space for memref.
This class helps build Operations.
Definition: Builders.h:177
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:429
MLIRContext * getContext() const
Definition: PatternMatch.h:906
Type conversion from builtin types to SPIR-V types for shader interface.
U cast() const
Definition: Types.h:250