MLIR  19.0.0git
MemRefToSPIRV.cpp File Reference
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/Debug.h"
#include <cassert>
#include <optional>


Classes

struct  MemoryRequirements
 

Namespaces

 mlir
 Include the generated interface declarations.
 

Macros

#define DEBUG_TYPE   "memref-to-spirv-pattern"
 
#define ATOMIC_CASE(kind, spirvOp)
 

Functions

static Value getOffsetForBitwidth (Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
 Returns the offset of the value in targetBits representation.
 
static Value adjustAccessChainForBitwidth (const SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)
 Returns an adjusted spirv::AccessChainOp.
 
static Value castBoolToIntN (Location loc, Value srcBool, Type dstType, OpBuilder &builder)
 Casts the given srcBool into an integer of dstType.
 
static Value shiftValue (Location loc, Value value, Value offset, Value mask, OpBuilder &builder)
 Returns the targetBits-bit value shifted by the given offset, cast to the destination type, and masked.
 
static bool isAllocationSupported (Operation *allocOp, MemRefType type)
 Returns true if allocations of the given memref type produced by allocOp can be lowered to SPIR-V.
 
static std::optional< spirv::Scope > getAtomicOpScope (MemRefType type)
 Returns the scope to use for atomic operations used for emulating store operations of unsupported integer bitwidths, based on the memref type.
 
static Value castIntNToBool (Location loc, Value srcInt, OpBuilder &builder)
 Casts the given srcInt into a boolean value.
 
static FailureOr< MemoryRequirements > calculateMemoryRequirements (Value accessedPtr, bool isNontemporal)
 Given an accessed SPIR-V pointer, calculates its alignment requirements, if any.
 
template<class LoadOrStoreOp >
static FailureOr< MemoryRequirements > calculateMemoryRequirements (Value accessedPtr, LoadOrStoreOp loadOrStoreOp)
 Given an accessed SPIR-V pointer and the original memref load/store memAccess op, calculates the alignment requirements, if any.
 
void mlir::populateMemRefToSPIRVPatterns (SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
 Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops.
 

Macro Definition Documentation

◆ ATOMIC_CASE

#define ATOMIC_CASE(kind, spirvOp)
Value:
case arith::AtomicRMWKind::kind: \
rewriter.replaceOpWithNewOp<spirv::spirvOp>( \
atomicOp, resultType, ptr, *scope, \
spirv::MemorySemantics::AcquireRelease, adaptor.getValue()); \
break
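Because the macro body is one complete case arm that ends in break (without a semicolon), it is intended to be used inside a switch over arith::AtomicRMWKind, with each use followed by a semicolon. The standalone sketch below mimics only that shape: the enum AtomicKind, the function pickAtomicOp, and the particular kind-to-op pairs are hypothetical stand-ins, and instead of calling rewriter.replaceOpWithNewOp the macro here just records which op name would be chosen.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for arith::AtomicRMWKind; only the switch shape
// matters for this sketch.
enum class AtomicKind { addi, maxs, ori, mulf };

// Each use expands to a complete "case ...: ...; break" arm. The caller
// supplies the trailing semicolon, just as with the real macro.
#define ATOMIC_CASE(kind, spirvOp)                                             \
  case AtomicKind::kind:                                                       \
    emitted = #spirvOp;                                                        \
    break

// Returns the name of the op this sketch would emit, or an empty string
// when the kind has no mapping here.
std::string pickAtomicOp(AtomicKind kind) {
  std::string emitted;
  switch (kind) {
    ATOMIC_CASE(addi, AtomicIAddOp);
    ATOMIC_CASE(maxs, AtomicSMaxOp);
    ATOMIC_CASE(ori, AtomicOrOp);
  default:
    break; // e.g. mulf: left for the caller to report as unsupported
  }
  return emitted;
}

#undef ATOMIC_CASE
```

Keeping the break inside the macro means every expansion is a self-contained arm, so forgetting a break between cases is not possible at the use sites.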

◆ DEBUG_TYPE

#define DEBUG_TYPE   "memref-to-spirv-pattern"

Definition at line 29 of file MemRefToSPIRV.cpp.

Function Documentation

◆ adjustAccessChainForBitwidth()

static Value adjustAccessChainForBitwidth (const SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)

Returns an adjusted spirv::AccessChainOp.

Depending on the available extensions and capabilities, some integer bitwidths (sourceBits) might not be supported. If a memref of an unsupported type is used during conversion, loads and stores to it must be rewritten to use a supported wider bitwidth (targetBits) and then extract the required bits. When accessing a 1-D array (spirv.array or spirv.rtarray), the last index is modified to load the bits needed; the extraction of the actual bits is handled separately. Note that this only works for 1-D arrays.
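A plain-C++ sketch of the index rewrite, assuming each targetBits-wide word packs targetBits / sourceBits consecutive elements. Under that assumption the word holding element srcIdx is srcIdx / (targetBits / sourceBits). The function name adjustedLastIndex is hypothetical; it models the arithmetic only, not the SPIR-V ops the pass builds.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: maps an index into a 1-D array of sourceBits-wide
// elements to the index of the targetBits-wide word that contains it.
// Assumes srcIdx >= 0 and that targetBits is a multiple of sourceBits.
std::int64_t adjustedLastIndex(std::int64_t srcIdx, int sourceBits,
                               int targetBits) {
  assert(srcIdx >= 0 && targetBits % sourceBits == 0);
  const std::int64_t elementsPerWord = targetBits / sourceBits;
  return srcIdx / elementsPerWord; // integer division picks the word
}
```

With sourceBits = 8 and targetBits = 32, indices 0 through 3 all map to word 0 and index 4 maps to word 1; which bits inside that word belong to the element is decided by the separate offset computation.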

Definition at line 70 of file MemRefToSPIRV.cpp.

◆ calculateMemoryRequirements() [1/2]

static FailureOr<MemoryRequirements> calculateMemoryRequirements (Value accessedPtr, bool isNontemporal)

Given an accessed SPIR-V pointer, calculates its alignment requirements, if any.

Definition at line 458 of file MemRefToSPIRV.cpp.

References mlir::failure(), mlir::get(), mlir::Value::getContext(), mlir::Value::getType(), and None.

Referenced by calculateMemoryRequirements().

◆ calculateMemoryRequirements() [2/2]

template<class LoadOrStoreOp>
static FailureOr<MemoryRequirements> calculateMemoryRequirements (Value accessedPtr, LoadOrStoreOp loadOrStoreOp)

Given an accessed SPIR-V pointer and the original memref load/store op, calculates the alignment requirements, if any.

Takes into account the alignment attributes applied to the load/store op.

Definition at line 496 of file MemRefToSPIRV.cpp.

References calculateMemoryRequirements(), and mlir::Operation::getAttrOfType().

◆ castBoolToIntN()

static Value castBoolToIntN (Location loc, Value srcBool, Type dstType, OpBuilder &builder)

Casts the given srcBool into an integer of dstType.
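A plain-C++ model of the cast's value semantics: true becomes 1 and false becomes 0 in the destination integer type. In the IR this is naturally a select between a one constant and a zero constant of dstType; that this is exactly how the pass emits it is an assumption (the reference list above only confirms that a zero constant, getZero(), is involved). The template name boolToIntN is hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical model: widen a boolean into an N-bit integer, mirroring a
// select(srcBool, one, zero) on constants of the destination type.
template <typename IntN>
IntN boolToIntN(bool srcBool) {
  const IntN one = 1;
  const IntN zero = 0;
  return srcBool ? one : zero;
}
```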

Definition at line 88 of file MemRefToSPIRV.cpp.

References mlir::OpBuilder::createOrFold(), mlir::Value::getType(), getZero(), and mlir::Type::isInteger().

Referenced by shiftValue().

◆ castIntNToBool()

static Value castIntNToBool (Location loc, Value srcInt, OpBuilder &builder)

Casts the given srcInt into a boolean value.
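A plain-C++ model of the reverse direction, under the assumption that the stored integer is always 0 or 1 (as written by the bool-to-int cast). Under that assumption, testing "!= 0" and testing "== 1" give the same answer; this sketch uses "!= 0", which is not necessarily the comparison the pass emits. The name intNToBool is hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical model: narrow an N-bit integer holding 0 or 1 back to bool.
template <typename IntN>
bool intNToBool(IntN srcInt) {
  return srcInt != IntN{0};
}
```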

Definition at line 165 of file MemRefToSPIRV.cpp.

References mlir::OpBuilder::createOrFold(), mlir::Value::getType(), and mlir::Type::isInteger().

◆ getAtomicOpScope()

static std::optional<spirv::Scope> getAtomicOpScope (MemRefType type)

Returns the scope to use for atomic operations used for emulating store operations of unsupported integer bitwidths, based on the memref type.

Returns std::nullopt on failure.
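Atomics are needed for these emulated stores because several narrow elements share one wide word, and a plain read-modify-write by one thread could overwrite a neighbouring element another thread just stored. The sketch below shows one way to avoid that with two atomics on the shared word: an AND that clears only the target element's bits, then an OR that sets them. It assumes std::atomic fetch_and/fetch_or as stand-ins for SPIR-V atomic AND/OR; whether the pass uses exactly this sequence is an assumption, and std::atomic has no scope parameter, so the scope returned by getAtomicOpScope is not modeled. The name storeNarrowAtomic is hypothetical.

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>

// Hypothetical model: store a sourceBits-wide value at bit `offset` of a
// shared 32-bit word without disturbing the other bits.
void storeNarrowAtomic(std::atomic<std::uint32_t> &word, std::uint32_t value,
                       std::uint32_t offset, std::uint32_t sourceBits) {
  const std::uint32_t mask =
      (sourceBits >= 32) ? 0xFFFFFFFFu : ((1u << sourceBits) - 1u);
  const std::uint32_t shifted = (value & mask) << offset;
  word.fetch_and(~(mask << offset)); // clear only this element's bits
  word.fetch_or(shifted);            // then write the new bits
}
```

For example, storing 0xAB at offset 8 (an 8-bit element) into a word holding 0x44332211 leaves 0x4433AB11: bytes 0, 2, and 3 are untouched.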

Definition at line 151 of file MemRefToSPIRV.cpp.

◆ getOffsetForBitwidth()

static Value getOffsetForBitwidth (Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)

Returns the offset of the value in targetBits representation.

srcIdx is an index into a 1-D array whose elements are sourceBits wide. It is assumed to be non-negative.

When the array is accessed as if its elements were targetBits wide, several source elements are loaded at once. This method returns the bit offset at which element srcIdx sits within the loaded value. For example, if sourceBits is 8 and targetBits is 32, one i32 holds four 8-bit elements, so the x-th element is located at bit offset (x % 4) * 8.
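The rule in the example generalizes to offset = (srcIdx % (targetBits / sourceBits)) * sourceBits. A plain-C++ sketch of that arithmetic, assuming targetBits is a multiple of sourceBits; the function name offsetForBitwidth is hypothetical and computes on integers rather than building IR values.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: bit offset of element srcIdx inside the
// targetBits-wide word that contains it.
std::int64_t offsetForBitwidth(std::int64_t srcIdx, int sourceBits,
                               int targetBits) {
  assert(srcIdx >= 0 && targetBits % sourceBits == 0);
  const std::int64_t elementsPerWord = targetBits / sourceBits;
  return (srcIdx % elementsPerWord) * sourceBits;
}
```

For sourceBits = 8 and targetBits = 32, indices 0, 5, and 7 give offsets 0, 8, and 24, matching (x % 4) * 8.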

Definition at line 48 of file MemRefToSPIRV.cpp.

References mlir::OpBuilder::createOrFold(), mlir::Builder::getIntegerAttr(), and mlir::Value::getType().

◆ isAllocationSupported()

static bool isAllocationSupported (Operation *allocOp, MemRefType type)

Returns true if allocations of the given memref type produced by allocOp can be lowered to SPIR-V.

Definition at line 124 of file MemRefToSPIRV.cpp.

References mlir::Type::isIntOrFloat().

◆ shiftValue()

static Value shiftValue (Location loc, Value value, Value offset, Value mask, OpBuilder &builder)

Returns the targetBits-bit value shifted by the given offset, cast to the destination type, and masked.
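A plain-C++ model of preparing a narrow value for a packed 32-bit store: zero-extend it to the word type, keep only the bits selected by mask, and move them to their bit offset. The brief lists these three steps but does not fix their order; this sketch assumes widen, then mask, then shift. The reference list suggests that i1 values are first widened through castBoolToIntN(), which is not modeled here. The name shiftIntoWord is hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical model: place an 8-bit value at bit `offset` of a 32-bit word.
std::uint32_t shiftIntoWord(std::uint8_t value, std::uint32_t offset,
                            std::uint32_t mask) {
  assert(offset < 32);
  const std::uint32_t widened = static_cast<std::uint32_t>(value); // zero-extend
  return (widened & mask) << offset; // mask first, then shift into place
}
```

For instance, placing 0xAB with mask 0xFF at offset 16 yields 0x00AB0000, ready to be ORed into the target word.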

Definition at line 101 of file MemRefToSPIRV.cpp.

References castBoolToIntN(), mlir::OpBuilder::create(), mlir::OpBuilder::createOrFold(), mlir::Builder::getIntegerType(), mlir::Type::getIntOrFloatBitWidth(), and mlir::Value::getType().