MLIR  15.0.0git
MemRefToSPIRV.cpp
Go to the documentation of this file.
1 //===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert MemRef dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "llvm/Support/Debug.h"
19 
20 #define DEBUG_TYPE "memref-to-spirv-pattern"
21 
22 using namespace mlir;
23 
24 //===----------------------------------------------------------------------===//
25 // Utility functions
26 //===----------------------------------------------------------------------===//
27 
28 /// Returns the offset of the value in `targetBits` representation.
29 ///
30 /// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
31 /// It's assumed to be non-negative.
32 ///
33 /// When accessing an element in the array treating as having elements of
34 /// `targetBits`, multiple values are loaded in the same time. The method
35 /// returns the offset where the `srcIdx` locates in the value. For example, if
36 /// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is
37 /// located at (x % 4) * 8. Because there are four elements in one i32, and one
38 /// element has 8 bits.
39 static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
40  int targetBits, OpBuilder &builder) {
41  assert(targetBits % sourceBits == 0);
42  IntegerType targetType = builder.getIntegerType(targetBits);
43  IntegerAttr idxAttr =
44  builder.getIntegerAttr(targetType, targetBits / sourceBits);
45  auto idx = builder.create<spirv::ConstantOp>(loc, targetType, idxAttr);
46  IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
47  auto srcBitsValue =
48  builder.create<spirv::ConstantOp>(loc, targetType, srcBitsAttr);
49  auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
50  return builder.create<spirv::IMulOp>(loc, targetType, m, srcBitsValue);
51 }
52 
53 /// Returns an adjusted spirv::AccessChainOp. Based on the
54 /// extension/capabilities, certain integer bitwidths `sourceBits` might not be
55 /// supported. During conversion if a memref of an unsupported type is used,
56 /// load/stores to this memref need to be modified to use a supported higher
57 /// bitwidth `targetBits` and extracting the required bits. For an accessing a
58 /// 1D array (spv.array or spv.rt_array), the last index is modified to load the
59 /// bits needed. The extraction of the actual bits needed are handled
60 /// separately. Note that this only works for a 1-D tensor.
62  spirv::AccessChainOp op,
63  int sourceBits, int targetBits,
64  OpBuilder &builder) {
65  assert(targetBits % sourceBits == 0);
66  const auto loc = op.getLoc();
67  IntegerType targetType = builder.getIntegerType(targetBits);
68  IntegerAttr attr =
69  builder.getIntegerAttr(targetType, targetBits / sourceBits);
70  auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
71  auto lastDim = op->getOperand(op.getNumOperands() - 1);
72  auto indices = llvm::to_vector<4>(op.indices());
73  // There are two elements if this is a 1-D tensor.
74  assert(indices.size() == 2);
75  indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
76  Type t = typeConverter.convertType(op.component_ptr().getType());
77  return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
78 }
79 
80 /// Returns the shifted `targetBits`-bit value with the given offset.
81 static Value shiftValue(Location loc, Value value, Value offset, Value mask,
82  int targetBits, OpBuilder &builder) {
83  Type targetType = builder.getIntegerType(targetBits);
84  Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
85  return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result,
86  offset);
87 }
88 
89 /// Returns true if the allocations of memref `type` generated from `allocOp`
90 /// can be lowered to SPIR-V.
91 static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
92  if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
94  spirv::StorageClass::Workgroup) != type.getMemorySpaceAsInt())
95  return false;
96  } else if (isa<memref::AllocaOp>(allocOp)) {
98  spirv::StorageClass::Function) != type.getMemorySpaceAsInt())
99  return false;
100  } else {
101  return false;
102  }
103 
104  // Currently only support static shape and int or float or vector of int or
105  // float element type.
106  if (!type.hasStaticShape())
107  return false;
108 
109  Type elementType = type.getElementType();
110  if (auto vecType = elementType.dyn_cast<VectorType>())
111  elementType = vecType.getElementType();
112  return elementType.isIntOrFloat();
113 }
114 
115 /// Returns the scope to use for atomic operations use for emulating store
116 /// operations of unsupported integer bitwidths, based on the memref
117 /// type. Returns None on failure.
118 static Optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
119  Optional<spirv::StorageClass> storageClass =
121  type.getMemorySpaceAsInt());
122  if (!storageClass)
123  return {};
124  switch (*storageClass) {
125  case spirv::StorageClass::StorageBuffer:
126  return spirv::Scope::Device;
127  case spirv::StorageClass::Workgroup:
128  return spirv::Scope::Workgroup;
129  default: {
130  }
131  }
132  return {};
133 }
134 
135 /// Casts the given `srcInt` into a boolean value.
136 static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
137  if (srcInt.getType().isInteger(1))
138  return srcInt;
139 
140  auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
141  return builder.create<spirv::IEqualOp>(loc, srcInt, one);
142 }
143 
144 /// Casts the given `srcBool` into an integer of `dstType`.
145 static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
146  OpBuilder &builder) {
147  assert(srcBool.getType().isInteger(1));
148  if (dstType.isInteger(1))
149  return srcBool;
150  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
151  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
152  return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
153 }
154 
155 //===----------------------------------------------------------------------===//
156 // Operation conversion
157 //===----------------------------------------------------------------------===//
158 
159 // Note that DRR cannot be used for the patterns in this file: we may need to
160 // convert type along the way, which requires ConversionPattern. DRR generates
161 // normal RewritePattern.
162 
163 namespace {
164 
165 /// Converts memref.alloca to SPIR-V Function variables.
166 class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
167 public:
169 
171  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
172  ConversionPatternRewriter &rewriter) const override;
173 };
174 
175 /// Converts an allocation operation to SPIR-V. Currently only supports lowering
176 /// to Workgroup memory when the size is constant. Note that this pattern needs
177 /// to be applied in a pass that runs at least at spv.module scope since it wil
178 /// ladd global variables into the spv.module.
179 class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
180 public:
182 
184  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
185  ConversionPatternRewriter &rewriter) const override;
186 };
187 
188 /// Removed a deallocation if it is a supported allocation. Currently only
189 /// removes deallocation if the memory space is workgroup memory.
190 class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
191 public:
193 
195  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
196  ConversionPatternRewriter &rewriter) const override;
197 };
198 
199 /// Converts memref.load to spv.Load.
200 class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
201 public:
203 
205  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
206  ConversionPatternRewriter &rewriter) const override;
207 };
208 
209 /// Converts memref.load to spv.Load.
210 class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
211 public:
213 
215  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
216  ConversionPatternRewriter &rewriter) const override;
217 };
218 
219 /// Converts memref.store to spv.Store on integers.
220 class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
221 public:
223 
225  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
226  ConversionPatternRewriter &rewriter) const override;
227 };
228 
229 /// Converts memref.store to spv.Store.
230 class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
231 public:
233 
235  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
236  ConversionPatternRewriter &rewriter) const override;
237 };
238 
239 } // namespace
240 
241 //===----------------------------------------------------------------------===//
242 // AllocaOp
243 //===----------------------------------------------------------------------===//
244 
246 AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
247  ConversionPatternRewriter &rewriter) const {
248  MemRefType allocType = allocaOp.getType();
249  if (!isAllocationSupported(allocaOp, allocType))
250  return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");
251 
252  // Get the SPIR-V type for the allocation.
253  Type spirvType = getTypeConverter()->convertType(allocType);
254  rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
255  spirv::StorageClass::Function,
256  /*initializer=*/nullptr);
257  return success();
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // AllocOp
262 //===----------------------------------------------------------------------===//
263 
265 AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
266  ConversionPatternRewriter &rewriter) const {
267  MemRefType allocType = operation.getType();
268  if (!isAllocationSupported(operation, allocType))
269  return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
270 
271  // Get the SPIR-V type for the allocation.
272  Type spirvType = getTypeConverter()->convertType(allocType);
273 
274  // Insert spv.GlobalVariable for this allocation.
275  Operation *parent =
276  SymbolTable::getNearestSymbolTable(operation->getParentOp());
277  if (!parent)
278  return failure();
279  Location loc = operation.getLoc();
280  spirv::GlobalVariableOp varOp;
281  {
282  OpBuilder::InsertionGuard guard(rewriter);
283  Block &entryBlock = *parent->getRegion(0).begin();
284  rewriter.setInsertionPointToStart(&entryBlock);
285  auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
286  std::string varName =
287  std::string("__workgroup_mem__") +
288  std::to_string(std::distance(varOps.begin(), varOps.end()));
289  varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
290  /*initializer=*/nullptr);
291  }
292 
293  // Get pointer to global variable at the current scope.
294  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
295  return success();
296 }
297 
298 //===----------------------------------------------------------------------===//
299 // DeallocOp
300 //===----------------------------------------------------------------------===//
301 
303 DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
304  OpAdaptor adaptor,
305  ConversionPatternRewriter &rewriter) const {
306  MemRefType deallocType = operation.memref().getType().cast<MemRefType>();
307  if (!isAllocationSupported(operation, deallocType))
308  return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
309  rewriter.eraseOp(operation);
310  return success();
311 }
312 
313 //===----------------------------------------------------------------------===//
314 // LoadOp
315 //===----------------------------------------------------------------------===//
316 
318 IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
319  ConversionPatternRewriter &rewriter) const {
320  auto loc = loadOp.getLoc();
321  auto memrefType = loadOp.memref().getType().cast<MemRefType>();
322  if (!memrefType.getElementType().isSignlessInteger())
323  return failure();
324 
325  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
326  spirv::AccessChainOp accessChainOp =
327  spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(),
328  adaptor.indices(), loc, rewriter);
329 
330  if (!accessChainOp)
331  return failure();
332 
333  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
334  bool isBool = srcBits == 1;
335  if (isBool)
336  srcBits = typeConverter.getOptions().boolNumBits;
337  Type pointeeType = typeConverter.convertType(memrefType)
339  .getPointeeType();
340  Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
341  Type dstType;
342  if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
343  dstType = arrayType.getElementType();
344  else
345  dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
346 
347  int dstBits = dstType.getIntOrFloatBitWidth();
348  assert(dstBits % srcBits == 0);
349 
350  // If the rewrited load op has the same bit width, use the loading value
351  // directly.
352  if (srcBits == dstBits) {
353  Value loadVal =
354  rewriter.create<spirv::LoadOp>(loc, accessChainOp.getResult());
355  if (isBool)
356  loadVal = castIntNToBool(loc, loadVal, rewriter);
357  rewriter.replaceOp(loadOp, loadVal);
358  return success();
359  }
360 
361  // Assume that getElementPtr() works linearizely. If it's a scalar, the method
362  // still returns a linearized accessing. If the accessing is not linearized,
363  // there will be offset issues.
364  assert(accessChainOp.indices().size() == 2);
365  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
366  srcBits, dstBits, rewriter);
367  Value spvLoadOp = rewriter.create<spirv::LoadOp>(
368  loc, dstType, adjustedPtr,
369  loadOp->getAttrOfType<spirv::MemoryAccessAttr>(
370  spirv::attributeName<spirv::MemoryAccess>()),
371  loadOp->getAttrOfType<IntegerAttr>("alignment"));
372 
373  // Shift the bits to the rightmost.
374  // ____XXXX________ -> ____________XXXX
375  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
376  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
377  Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
378  loc, spvLoadOp.getType(), spvLoadOp, offset);
379 
380  // Apply the mask to extract corresponding bits.
381  Value mask = rewriter.create<spirv::ConstantOp>(
382  loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
383  result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
384 
385  // Apply sign extension on the loading value unconditionally. The signedness
386  // semantic is carried in the operator itself, we relies other pattern to
387  // handle the casting.
388  IntegerAttr shiftValueAttr =
389  rewriter.getIntegerAttr(dstType, dstBits - srcBits);
390  Value shiftValue =
391  rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
392  result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
393  shiftValue);
394  result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
395  shiftValue);
396 
397  if (isBool) {
398  dstType = typeConverter.convertType(loadOp.getType());
399  mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter);
400  result = rewriter.create<spirv::IEqualOp>(loc, result, mask);
401  } else if (result.getType().getIntOrFloatBitWidth() !=
402  static_cast<unsigned>(dstBits)) {
403  result = rewriter.create<spirv::SConvertOp>(loc, dstType, result);
404  }
405  rewriter.replaceOp(loadOp, result);
406 
407  assert(accessChainOp.use_empty());
408  rewriter.eraseOp(accessChainOp);
409 
410  return success();
411 }
412 
414 LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
415  ConversionPatternRewriter &rewriter) const {
416  auto memrefType = loadOp.memref().getType().cast<MemRefType>();
417  if (memrefType.getElementType().isSignlessInteger())
418  return failure();
419  auto loadPtr = spirv::getElementPtr(
420  *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.memref(),
421  adaptor.indices(), loadOp.getLoc(), rewriter);
422 
423  if (!loadPtr)
424  return failure();
425 
426  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
427  return success();
428 }
429 
431 IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
432  ConversionPatternRewriter &rewriter) const {
433  auto memrefType = storeOp.memref().getType().cast<MemRefType>();
434  if (!memrefType.getElementType().isSignlessInteger())
435  return failure();
436 
437  auto loc = storeOp.getLoc();
438  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
439  spirv::AccessChainOp accessChainOp =
440  spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(),
441  adaptor.indices(), loc, rewriter);
442 
443  if (!accessChainOp)
444  return failure();
445 
446  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
447 
448  bool isBool = srcBits == 1;
449  if (isBool)
450  srcBits = typeConverter.getOptions().boolNumBits;
451 
452  Type pointeeType = typeConverter.convertType(memrefType)
454  .getPointeeType();
455  Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
456  Type dstType;
457  if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
458  dstType = arrayType.getElementType();
459  else
460  dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
461 
462  int dstBits = dstType.getIntOrFloatBitWidth();
463  assert(dstBits % srcBits == 0);
464 
465  if (srcBits == dstBits) {
466  Value storeVal = adaptor.value();
467  if (isBool)
468  storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
469  rewriter.replaceOpWithNewOp<spirv::StoreOp>(
470  storeOp, accessChainOp.getResult(), storeVal);
471  return success();
472  }
473 
474  // Since there are multi threads in the processing, the emulation will be done
475  // with atomic operations. E.g., if the storing value is i8, rewrite the
476  // StoreOp to
477  // 1) load a 32-bit integer
478  // 2) clear 8 bits in the loading value
479  // 3) store 32-bit value back
480  // 4) load a 32-bit integer
481  // 5) modify 8 bits in the loading value
482  // 6) store 32-bit value back
483  // The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step
484  // 4 to step 6 are done by AtomicOr as another atomic step.
485  assert(accessChainOp.indices().size() == 2);
486  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
487  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
488 
489  // Create a mask to clear the destination. E.g., if it is the second i8 in
490  // i32, 0xFFFF00FF is created.
491  Value mask = rewriter.create<spirv::ConstantOp>(
492  loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
493  Value clearBitsMask =
494  rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
495  clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
496 
497  Value storeVal = adaptor.value();
498  if (isBool)
499  storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
500  storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter);
501  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
502  srcBits, dstBits, rewriter);
503  Optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
504  if (!scope)
505  return failure();
506  Value result = rewriter.create<spirv::AtomicAndOp>(
507  loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
508  clearBitsMask);
509  result = rewriter.create<spirv::AtomicOrOp>(
510  loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
511  storeVal);
512 
513  // The AtomicOrOp has no side effect. Since it is already inserted, we can
514  // just remove the original StoreOp. Note that rewriter.replaceOp()
515  // doesn't work because it only accepts that the numbers of result are the
516  // same.
517  rewriter.eraseOp(storeOp);
518 
519  assert(accessChainOp.use_empty());
520  rewriter.eraseOp(accessChainOp);
521 
522  return success();
523 }
524 
526 StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
527  ConversionPatternRewriter &rewriter) const {
528  auto memrefType = storeOp.memref().getType().cast<MemRefType>();
529  if (memrefType.getElementType().isSignlessInteger())
530  return failure();
531  auto storePtr = spirv::getElementPtr(
532  *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.memref(),
533  adaptor.indices(), storeOp.getLoc(), rewriter);
534 
535  if (!storePtr)
536  return failure();
537 
538  rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
539  adaptor.value());
540  return success();
541 }
542 
543 //===----------------------------------------------------------------------===//
544 // Pattern population
545 //===----------------------------------------------------------------------===//
546 
547 namespace mlir {
549  RewritePatternSet &patterns) {
550  patterns
551  .add<AllocaOpPattern, AllocOpPattern, DeallocOpPattern, IntLoadOpPattern,
552  IntStoreOpPattern, LoadOpPattern, StoreOpPattern>(
553  typeConverter, patterns.getContext());
554 }
555 } // namespace mlir
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
Returns the offset of the value in targetBits representation.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
Block represents an ordered list of Operations.
Definition: Block.h:29
Value getOperand(unsigned idx)
Definition: Operation.h:274
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
Definition: Types.cpp:31
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:688
static unsigned getMemorySpaceForStorageClass(spirv::StorageClass)
Returns the corresponding memory space for memref given a SPIR-V storage class.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:87
static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)
Returns an adjusted spirv::AccessChainOp.
static bool isAllocationSupported(Operation *allocOp, MemRefType type)
Returns true if the allocations of memref type generated from allocOp can be lowered to SPIR-V...
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:380
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
iterator begin()
Definition: Region.h:55
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder)
Casts the given srcBool into an integer of dstType.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder)
Casts the given srcInt into a boolean value.
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:170
U dyn_cast() const
Definition: Types.h:256
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
static Optional< spirv::Scope > getAtomicOpScope(MemRefType type)
Returns the scope to use for atomic operations use for emulating store operations of unsupported inte...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
static Value shiftValue(Location loc, Value value, Value offset, Value mask, int targetBits, OpBuilder &builder)
Returns the shifted targetBits-bit value with the given offset.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:369
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:286
Type getType() const
Return the type of this value.
Definition: Value.h:118
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
spirv::AccessChainOp getElementPtr(SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr...
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'. ...
Definition: Block.h:184
SPIR-V struct type.
Definition: SPIRVTypes.h:278
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:91
This class implements a pattern rewriter for use with ConversionPatterns.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
static Optional< spirv::StorageClass > getStorageClassForMemorySpace(unsigned space)
Returns the SPIR-V storage class given a memory space for memref.
This class helps build Operations.
Definition: Builders.h:184
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:484
MLIRContext * getContext() const
Type conversion from builtin types to SPIR-V types for shader interface.
U cast() const
Definition: Types.h:262