25#include "llvm/ADT/DenseMap.h"
26#include "llvm/ADT/STLExtras.h"
31#define GEN_PASS_DEF_SPIRVUNIFYALIASEDRESOURCEPASS
32#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
49 moduleOp->walk([&aliasedResources](spirv::GlobalVariableOp varOp) {
50 if (varOp->getAttrOfType<UnitAttr>(
"aliased")) {
51 std::optional<uint32_t> set = varOp.getDescriptorSet();
52 std::optional<uint32_t> binding = varOp.getBinding();
54 aliasedResources[{*set, *binding}].push_back(varOp);
57 return aliasedResources;
64 auto ptrType = dyn_cast<spirv::PointerType>(type);
68 auto structType = dyn_cast<spirv::StructType>(ptrType.getPointeeType());
69 if (!structType || structType.getNumElements() != 1)
73 dyn_cast<spirv::RuntimeArrayType>(structType.getElementType(0));
77 return rtArrayType.getElementType();
83static std::optional<int>
89 scalarNumBits.reserve(types.size());
90 vectorNumBits.reserve(types.size());
91 vectorIndices.reserve(types.size());
93 for (
const auto &indexedTypes : llvm::enumerate(types)) {
96 if (
auto vectorType = dyn_cast<VectorType>(type)) {
97 if (vectorType.getNumElements() % 2 != 0)
105 scalarNumBits.push_back(
106 vectorType.getElementType().getIntOrFloatBitWidth());
107 vectorNumBits.push_back(*numBytes * 8);
108 vectorIndices.push_back(indexedTypes.index());
114 if (!vectorNumBits.empty()) {
118 auto *minVal = llvm::min_element(vectorNumBits);
121 if (llvm::any_of(vectorNumBits,
122 [&](
int bits) {
return bits % *minVal != 0; }))
127 int index = vectorIndices[std::distance(vectorNumBits.begin(), minVal)];
128 int baseNumBits = scalarNumBits[
index];
129 if (llvm::any_of(scalarNumBits,
130 [&](
int bits) {
return bits % baseNumBits != 0; }))
138 auto *minVal = llvm::min_element(scalarNumBits);
139 if (llvm::any_of(scalarNumBits,
140 [minVal](
int64_t bit) {
return bit % *minVal != 0; }))
142 return std::distance(scalarNumBits.begin(), minVal);
163class ResourceAliasAnalysis {
167 explicit ResourceAliasAnalysis(Operation *);
171 bool shouldUnify(Operation *op)
const;
177 spirv::GlobalVariableOp
178 getCanonicalResource(
const Descriptor &descriptor)
const;
179 spirv::GlobalVariableOp
180 getCanonicalResource(spirv::GlobalVariableOp varOp)
const;
183 spirv::SPIRVType
getElementType(spirv::GlobalVariableOp varOp)
const;
188 void recordIfUnifiable(
const Descriptor &descriptor,
189 ArrayRef<spirv::GlobalVariableOp> resources);
205ResourceAliasAnalysis::ResourceAliasAnalysis(
Operation *root) {
213 for (
const auto &descriptorResource : aliasedResources) {
214 recordIfUnifiable(descriptorResource.first, descriptorResource.second);
218bool ResourceAliasAnalysis::shouldUnify(Operation *op)
const {
222 if (
auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
223 auto canonicalOp = getCanonicalResource(varOp);
224 return canonicalOp && varOp != canonicalOp;
226 if (
auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) {
227 auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
230 return shouldUnify(varOp);
233 if (
auto acOp = dyn_cast<spirv::AccessChainOp>(op))
234 return shouldUnify(acOp.getBasePtr().getDefiningOp());
235 if (
auto loadOp = dyn_cast<spirv::LoadOp>(op))
236 return shouldUnify(loadOp.getPtr().getDefiningOp());
237 if (
auto storeOp = dyn_cast<spirv::StoreOp>(op))
238 return shouldUnify(storeOp.getPtr().getDefiningOp());
243spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
245 auto varIt = canonicalResourceMap.find(descriptor);
246 if (varIt == canonicalResourceMap.end())
248 return varIt->second;
251spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
252 spirv::GlobalVariableOp varOp)
const {
253 auto descriptorIt = descriptorMap.find(varOp);
254 if (descriptorIt == descriptorMap.end())
256 return getCanonicalResource(descriptorIt->second);
260ResourceAliasAnalysis::getElementType(spirv::GlobalVariableOp varOp)
const {
261 auto it = elementTypeMap.find(varOp);
262 if (it == elementTypeMap.end())
267void ResourceAliasAnalysis::recordIfUnifiable(
268 const Descriptor &descriptor, ArrayRef<spirv::GlobalVariableOp> resources) {
270 SmallVector<spirv::SPIRVType> elementTypes;
271 for (spirv::GlobalVariableOp resource : resources) {
276 auto type = cast<spirv::SPIRVType>(elementType);
277 if (!type.isScalarOrVector())
280 elementTypes.push_back(type);
288 resourceMap[descriptor].assign(resources.begin(), resources.end());
289 canonicalResourceMap[descriptor] = resources[*index];
290 for (
const auto &resource : llvm::enumerate(resources)) {
291 descriptorMap[resource.value()] = descriptor;
292 elementTypeMap[resource.value()] = elementTypes[resource.index()];
300template <
typename OpTy>
316 ConversionPatternRewriter &rewriter)
const override {
319 rewriter.eraseOp(varOp);
329 ConversionPatternRewriter &rewriter)
const override {
331 auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
332 auto srcVarOp = cast<spirv::GlobalVariableOp>(
334 auto dstVarOp =
analysis.getCanonicalResource(srcVarOp);
335 rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, dstVarOp);
345 ConversionPatternRewriter &rewriter)
const override {
346 auto addressOp = acOp.getBasePtr().getDefiningOp<spirv::AddressOfOp>();
348 return rewriter.notifyMatchFailure(acOp,
"base ptr not addressof op");
350 auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
351 auto srcVarOp = cast<spirv::GlobalVariableOp>(
353 auto dstVarOp =
analysis.getCanonicalResource(srcVarOp);
358 if (srcElemType == dstElemType ||
362 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
363 acOp, adaptor.getBasePtr(), adaptor.getIndices());
369 if (srcElemType.
isIntOrFloat() && isa<VectorType>(dstElemType)) {
376 assert(dstNumBytes >= srcNumBytes && dstNumBytes % srcNumBytes == 0);
378 auto indices = llvm::to_vector<4>(acOp.getIndices());
382 int ratio = dstNumBytes / srcNumBytes;
383 auto ratioValue = spirv::ConstantOp::create(
384 rewriter, loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
387 spirv::SDivOp::create(rewriter, loc, indexType, oldIndex, ratioValue);
388 indices.push_back(spirv::SModOp::create(rewriter, loc, indexType,
389 oldIndex, ratioValue));
391 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
392 acOp, adaptor.getBasePtr(),
indices);
397 (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) {
403 assert(srcNumBytes >= dstNumBytes && srcNumBytes % dstNumBytes == 0);
405 auto indices = llvm::to_vector<4>(acOp.getIndices());
409 int ratio = srcNumBytes / dstNumBytes;
410 auto ratioValue = spirv::ConstantOp::create(
411 rewriter, loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
414 spirv::IMulOp::create(rewriter, loc, indexType, oldIndex, ratioValue);
416 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
421 return rewriter.notifyMatchFailure(
422 acOp,
"unsupported src/dst types for spirv.AccessChain");
431 ConversionPatternRewriter &rewriter)
const override {
432 auto srcPtrType = cast<spirv::PointerType>(loadOp.getPtr().getType());
433 auto srcElemType = cast<spirv::SPIRVType>(srcPtrType.getPointeeType());
434 auto dstPtrType = cast<spirv::PointerType>(adaptor.getPtr().getType());
435 auto dstElemType = cast<spirv::SPIRVType>(dstPtrType.getPointeeType());
438 auto newLoadOp = spirv::LoadOp::create(rewriter, loc, adaptor.getPtr());
439 if (srcElemType == dstElemType) {
440 rewriter.replaceOp(loadOp, newLoadOp->getResults());
445 auto castOp = spirv::BitcastOp::create(rewriter, loc, srcElemType,
446 newLoadOp.getValue());
447 rewriter.replaceOp(loadOp, castOp->getResults());
453 (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) {
460 assert(srcNumBytes > dstNumBytes && srcNumBytes % dstNumBytes == 0);
461 int ratio = srcNumBytes / dstNumBytes;
463 return rewriter.notifyMatchFailure(loadOp,
"more than 4 components");
466 components.reserve(ratio);
467 components.push_back(newLoadOp);
469 auto acOp = adaptor.getPtr().getDefiningOp<spirv::AccessChainOp>();
471 return rewriter.notifyMatchFailure(loadOp,
"ptr not spirv.AccessChain");
473 auto i32Type = rewriter.getI32Type();
474 Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter);
475 auto indices = llvm::to_vector<4>(acOp.getIndices());
476 for (
int i = 1; i < ratio; ++i) {
478 indices.back() = spirv::IAddOp::create(rewriter, loc, i32Type,
480 auto componentAcOp = spirv::AccessChainOp::create(
481 rewriter, loc, acOp.getBasePtr(),
indices);
484 components.push_back(
485 spirv::LoadOp::create(rewriter, loc, componentAcOp));
493 Type vectorType = srcElemType;
494 if (!isa<VectorType>(srcElemType))
495 vectorType = VectorType::get({ratio}, dstElemType);
499 if (
auto srcElemVecType = dyn_cast<VectorType>(srcElemType))
500 if (
auto dstElemVecType = dyn_cast<VectorType>(dstElemType)) {
501 if (srcElemVecType.getElementType() !=
502 dstElemVecType.getElementType()) {
504 dstNumBytes / (srcElemVecType.getElementTypeBitWidth() / 8);
508 Type castType = srcElemVecType.getElementType();
510 castType = VectorType::get({count}, castType);
512 for (Value &c : components)
513 c = spirv::BitcastOp::create(rewriter, loc, castType, c);
516 Value vectorValue = spirv::CompositeConstructOp::create(
517 rewriter, loc, vectorType, components);
519 if (!isa<VectorType>(srcElemType))
521 spirv::BitcastOp::create(rewriter, loc, srcElemType, vectorValue);
522 rewriter.replaceOp(loadOp, vectorValue);
526 return rewriter.notifyMatchFailure(
527 loadOp,
"unsupported src/dst types for spirv.Load");
536 ConversionPatternRewriter &rewriter)
const override {
538 cast<spirv::PointerType>(storeOp.getPtr().getType()).getPointeeType();
540 cast<spirv::PointerType>(adaptor.getPtr().getType()).getPointeeType();
542 return rewriter.notifyMatchFailure(storeOp,
"not scalar type");
544 return rewriter.notifyMatchFailure(storeOp,
"different bitwidth");
547 Value value = adaptor.getValue();
548 if (srcElemType != dstElemType)
549 value = spirv::BitcastOp::create(rewriter, loc, dstElemType, value);
550 rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.getPtr(),
551 value, storeOp->getAttrs());
561class UnifyAliasedResourcePass final
562 :
public spirv::impl::SPIRVUnifyAliasedResourcePassBase<
563 UnifyAliasedResourcePass> {
566 : getTargetEnvFn(std::move(getTargetEnv)) {}
568 void runOnOperation()
override;
574void UnifyAliasedResourcePass::runOnOperation() {
575 spirv::ModuleOp moduleOp = getOperation();
578 if (getTargetEnvFn) {
582 spirv::TargetEnvAttr targetEnv = getTargetEnvFn(moduleOp);
584 bool isVulkanOnAppleDevices =
585 clientAPI == spirv::ClientAPI::Vulkan &&
587 if (clientAPI != spirv::ClientAPI::WebGPU &&
588 clientAPI != spirv::ClientAPI::Metal && !isVulkanOnAppleDevices)
593 ResourceAliasAnalysis &
analysis = getAnalysis<ResourceAliasAnalysis>();
595 ConversionTarget
target(*context);
596 target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp,
597 spirv::AccessChainOp, spirv::LoadOp,
600 target.addLegalDialect<spirv::SPIRVDialect>();
603 RewritePatternSet
patterns(context);
604 patterns.add<ConvertVariable, ConvertAddressOf, ConvertAccessChain,
605 ConvertLoad, ConvertStore>(
analysis, context);
607 return signalPassFailure();
614 for (
const auto &dr : resourceMap) {
615 const auto &resources = dr.second;
616 if (resources.size() == 1)
617 resources.front()->removeAttr(
"aliased");
622std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
624 return std::make_unique<UnifyAliasedResourcePass>(std::move(getTargetEnv));
static Type getElementType(Type type)
Determine the element type of type.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
static std::optional< int > deduceCanonicalResource(ArrayRef< spirv::SPIRVType > types)
Given a list of resource element types, returns the index of the canonical resource that all resource...
DenseMap< Descriptor, SmallVector< spirv::GlobalVariableOp > > AliasedResourceMap
static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp)
Collects all aliased resources in the given SPIR-V moduleOp.
static Type getRuntimeArrayElementType(Type type)
Returns the element type if the given type is a runtime array resource: !spirv.ptr<!...
static bool areSameBitwidthScalarType(Type a, Type b)
std::pair< uint32_t, uint32_t > Descriptor
ConvertAliasResource(const ResourceAliasAnalysis &analysis, MLIRContext *context, PatternBenefit benefit=1)
const ResourceAliasAnalysis & analysis
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
Vendor getVendorID() const
Returns the vendor ID.
ClientAPI getClientAPI() const
Returns the client API.
std::unique_ptr< OperationPass< spirv::ModuleOp > > createUnifyAliasedResourcePass(GetTargetEnvFn getTargetEnv=nullptr)
std::function< spirv::TargetEnvAttr(spirv::ModuleOp)> GetTargetEnvFn
Creates an operation pass that unifies access of multiple aliased resources into access of one single...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
ConvertAliasResource(const ResourceAliasAnalysis &analysis, MLIRContext *context, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertAliasResource(const ResourceAliasAnalysis &analysis, MLIRContext *context, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertAliasResource(const ResourceAliasAnalysis &analysis, MLIRContext *context, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertAliasResource(const ResourceAliasAnalysis &analysis, MLIRContext *context, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertAliasResource(const ResourceAliasAnalysis &analysis, MLIRContext *context, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override