18 #include "llvm/Support/Debug.h" 20 #define DEBUG_TYPE "memref-to-spirv-pattern" 41 assert(targetBits % sourceBits == 0);
45 auto idx = builder.
create<spirv::ConstantOp>(loc, targetType, idxAttr);
46 IntegerAttr srcBitsAttr = builder.
getIntegerAttr(targetType, sourceBits);
48 builder.
create<spirv::ConstantOp>(loc, targetType, srcBitsAttr);
49 auto m = builder.
create<spirv::UModOp>(loc, srcIdx, idx);
50 return builder.
create<spirv::IMulOp>(loc, targetType, m, srcBitsValue);
62 spirv::AccessChainOp op,
63 int sourceBits,
int targetBits,
65 assert(targetBits % sourceBits == 0);
66 const auto loc = op.getLoc();
70 auto idx = builder.
create<spirv::ConstantOp>(loc, targetType, attr);
71 auto lastDim = op->
getOperand(op.getNumOperands() - 1);
72 auto indices = llvm::to_vector<4>(op.indices());
74 assert(indices.size() == 2);
75 indices.back() = builder.
create<spirv::SDivOp>(loc, lastDim, idx);
77 return builder.
create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
85 return builder.
create<spirv::ShiftLeftLogicalOp>(loc, targetType, result,
92 if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
94 spirv::StorageClass::Workgroup) != type.getMemorySpaceAsInt())
96 }
else if (isa<memref::AllocaOp>(allocOp)) {
98 spirv::StorageClass::Function) != type.getMemorySpaceAsInt())
106 if (!type.hasStaticShape())
109 Type elementType = type.getElementType();
110 if (
auto vecType = elementType.
dyn_cast<VectorType>())
111 elementType = vecType.getElementType();
121 type.getMemorySpaceAsInt());
124 switch (*storageClass) {
125 case spirv::StorageClass::StorageBuffer:
126 return spirv::Scope::Device;
127 case spirv::StorageClass::Workgroup:
128 return spirv::Scope::Workgroup;
140 auto one = spirv::ConstantOp::getOne(srcInt.
getType(), loc, builder);
141 return builder.
create<spirv::IEqualOp>(loc, srcInt, one);
150 Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
151 Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
152 return builder.
create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
171 matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
184 matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
195 matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
205 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
215 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
225 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
235 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
246 AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
248 MemRefType allocType = allocaOp.getType();
253 Type spirvType = getTypeConverter()->convertType(allocType);
255 spirv::StorageClass::Function,
265 AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
267 MemRefType allocType = operation.getType();
272 Type spirvType = getTypeConverter()->convertType(allocType);
280 spirv::GlobalVariableOp varOp;
285 auto varOps = entryBlock.
getOps<spirv::GlobalVariableOp>();
286 std::string varName =
287 std::string(
"__workgroup_mem__") +
288 std::to_string(std::distance(varOps.begin(), varOps.end()));
289 varOp = rewriter.
create<spirv::GlobalVariableOp>(loc, spirvType, varName,
303 DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
306 MemRefType deallocType = operation.memref().getType().cast<MemRefType>();
318 IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
320 auto loc = loadOp.getLoc();
321 auto memrefType = loadOp.memref().getType().cast<MemRefType>();
322 if (!memrefType.getElementType().isSignlessInteger())
325 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
326 spirv::AccessChainOp accessChainOp =
328 adaptor.indices(), loc, rewriter);
333 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
334 bool isBool = srcBits == 1;
336 srcBits = typeConverter.getOptions().boolNumBits;
337 Type pointeeType = typeConverter.convertType(memrefType)
343 dstType = arrayType.getElementType();
348 assert(dstBits % srcBits == 0);
352 if (srcBits == dstBits) {
354 rewriter.
create<spirv::LoadOp>(loc, accessChainOp.getResult());
364 assert(accessChainOp.indices().size() == 2);
366 srcBits, dstBits, rewriter);
368 loc, dstType, adjustedPtr,
369 loadOp->getAttrOfType<spirv::MemoryAccessAttr>(
370 spirv::attributeName<spirv::MemoryAccess>()),
371 loadOp->getAttrOfType<IntegerAttr>(
"alignment"));
375 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
377 Value result = rewriter.
create<spirv::ShiftRightArithmeticOp>(
378 loc, spvLoadOp.
getType(), spvLoadOp, offset);
382 loc, dstType, rewriter.
getIntegerAttr(dstType, (1 << srcBits) - 1));
383 result = rewriter.
create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
388 IntegerAttr shiftValueAttr =
391 rewriter.
create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
392 result = rewriter.
create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
394 result = rewriter.
create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
398 dstType = typeConverter.convertType(loadOp.getType());
399 mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter);
400 result = rewriter.
create<spirv::IEqualOp>(loc, result, mask);
401 }
else if (result.getType().getIntOrFloatBitWidth() !=
402 static_cast<unsigned>(dstBits)) {
403 result = rewriter.
create<spirv::SConvertOp>(loc, dstType, result);
407 assert(accessChainOp.use_empty());
408 rewriter.
eraseOp(accessChainOp);
414 LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
416 auto memrefType = loadOp.memref().getType().cast<MemRefType>();
417 if (memrefType.getElementType().isSignlessInteger())
420 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.memref(),
421 adaptor.indices(), loadOp.getLoc(), rewriter);
431 IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
433 auto memrefType = storeOp.memref().getType().cast<MemRefType>();
434 if (!memrefType.getElementType().isSignlessInteger())
437 auto loc = storeOp.getLoc();
438 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
439 spirv::AccessChainOp accessChainOp =
441 adaptor.indices(), loc, rewriter);
446 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
448 bool isBool = srcBits == 1;
450 srcBits = typeConverter.getOptions().boolNumBits;
452 Type pointeeType = typeConverter.convertType(memrefType)
458 dstType = arrayType.getElementType();
463 assert(dstBits % srcBits == 0);
465 if (srcBits == dstBits) {
466 Value storeVal = adaptor.value();
470 storeOp, accessChainOp.getResult(), storeVal);
485 assert(accessChainOp.indices().size() == 2);
486 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
492 loc, dstType, rewriter.
getIntegerAttr(dstType, (1 << srcBits) - 1));
493 Value clearBitsMask =
494 rewriter.
create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
495 clearBitsMask = rewriter.
create<spirv::NotOp>(loc, dstType, clearBitsMask);
497 Value storeVal = adaptor.value();
500 storeVal =
shiftValue(loc, storeVal, offset, mask, dstBits, rewriter);
502 srcBits, dstBits, rewriter);
506 Value result = rewriter.
create<spirv::AtomicAndOp>(
507 loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
509 result = rewriter.
create<spirv::AtomicOrOp>(
510 loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
519 assert(accessChainOp.use_empty());
520 rewriter.
eraseOp(accessChainOp);
526 StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
528 auto memrefType = storeOp.memref().getType().cast<MemRefType>();
529 if (memrefType.getElementType().isSignlessInteger())
532 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.memref(),
533 adaptor.indices(), storeOp.getLoc(), rewriter);
551 .
add<AllocaOpPattern, AllocOpPattern, DeallocOpPattern, IntLoadOpPattern,
552 IntStoreOpPattern, LoadOpPattern, StoreOpPattern>(
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
Operation is a basic unit of execution within MLIR.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
Returns the offset of the value in targetBits representation.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
Block represents an ordered list of Operations.
Value getOperand(unsigned idx)
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
static unsigned getMemorySpaceForStorageClass(spirv::StorageClass)
Returns the corresponding memory space for memref given a SPIR-V storage class.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)
Returns an adjusted spirv::AccessChainOp.
static bool isAllocationSupported(Operation *allocOp, MemRefType type)
Returns true if the allocations of memref type generated from allocOp can be lowered to SPIR-V...
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an efficient way to signal success or failure.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder)
Casts the given srcBool into an integer of dstType.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder)
Casts the given srcInt into a boolean value.
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static Optional< spirv::Scope > getAtomicOpScope(MemRefType type)
Returns the scope to use for atomic operations use for emulating store operations of unsupported inte...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static Value shiftValue(Location loc, Value value, Value offset, Value mask, int targetBits, OpBuilder &builder)
Returns the shifted targetBits-bit value with the given offset.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
RAII guard to reset the insertion point of the builder when destroyed.
Type getType() const
Return the type of this value.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
spirv::AccessChainOp getElementPtr(SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr...
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'. ...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class implements a pattern rewriter for use with ConversionPatterns.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
static Optional< spirv::StorageClass > getStorageClassForMemorySpace(unsigned space)
Returns the SPIR-V storage class given a memory space for memref.
This class helps build Operations.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MLIRContext * getContext() const
Type conversion from builtin types to SPIR-V types for shader interface.