25 #include "llvm/ADT/DenseMap.h"
26 #include "llvm/ADT/STLExtras.h"
31 #define GEN_PASS_DEF_SPIRVUNIFYALIASEDRESOURCEPASS
32 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
49 moduleOp->walk([&aliasedResources](spirv::GlobalVariableOp varOp) {
50 if (varOp->getAttrOfType<UnitAttr>(
"aliased")) {
51 std::optional<uint32_t> set = varOp.getDescriptorSet();
52 std::optional<uint32_t> binding = varOp.getBinding();
54 aliasedResources[{*set, *binding}].push_back(varOp);
57 return aliasedResources;
64 auto ptrType = dyn_cast<spirv::PointerType>(type);
68 auto structType = dyn_cast<spirv::StructType>(ptrType.getPointeeType());
69 if (!structType || structType.getNumElements() != 1)
73 dyn_cast<spirv::RuntimeArrayType>(structType.getElementType(0));
77 return rtArrayType.getElementType();
83 static std::optional<int>
89 scalarNumBits.reserve(types.size());
90 vectorNumBits.reserve(types.size());
91 vectorIndices.reserve(types.size());
96 if (
auto vectorType = dyn_cast<VectorType>(type)) {
97 if (vectorType.getNumElements() % 2 != 0)
105 scalarNumBits.push_back(
106 vectorType.getElementType().getIntOrFloatBitWidth());
107 vectorNumBits.push_back(*numBytes * 8);
108 vectorIndices.push_back(indexedTypes.index());
114 if (!vectorNumBits.empty()) {
118 auto *minVal = llvm::min_element(vectorNumBits);
121 if (llvm::any_of(vectorNumBits,
122 [&](
int bits) {
return bits % *minVal != 0; }))
127 int index = vectorIndices[std::distance(vectorNumBits.begin(), minVal)];
128 int baseNumBits = scalarNumBits[index];
129 if (llvm::any_of(scalarNumBits,
130 [&](
int bits) {
return bits % baseNumBits != 0; }))
138 auto *minVal = llvm::min_element(scalarNumBits);
139 if (llvm::any_of(scalarNumBits,
140 [minVal](int64_t bit) {
return bit % *minVal != 0; }))
142 return std::distance(scalarNumBits.begin(), minVal);
163 class ResourceAliasAnalysis {
167 explicit ResourceAliasAnalysis(
Operation *);
177 spirv::GlobalVariableOp
178 getCanonicalResource(
const Descriptor &descriptor)
const;
179 spirv::GlobalVariableOp
180 getCanonicalResource(spirv::GlobalVariableOp varOp)
const;
188 void recordIfUnifiable(
const Descriptor &descriptor,
205 ResourceAliasAnalysis::ResourceAliasAnalysis(
Operation *root) {
213 for (
const auto &descriptorResource : aliasedResources) {
214 recordIfUnifiable(descriptorResource.first, descriptorResource.second);
218 bool ResourceAliasAnalysis::shouldUnify(
Operation *op)
const {
222 if (
auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
223 auto canonicalOp = getCanonicalResource(varOp);
224 return canonicalOp && varOp != canonicalOp;
226 if (
auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) {
227 auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
229 SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable());
230 return shouldUnify(varOp);
233 if (
auto acOp = dyn_cast<spirv::AccessChainOp>(op))
234 return shouldUnify(acOp.getBasePtr().getDefiningOp());
235 if (
auto loadOp = dyn_cast<spirv::LoadOp>(op))
236 return shouldUnify(loadOp.getPtr().getDefiningOp());
237 if (
auto storeOp = dyn_cast<spirv::StoreOp>(op))
238 return shouldUnify(storeOp.getPtr().getDefiningOp());
243 spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
245 auto varIt = canonicalResourceMap.find(descriptor);
246 if (varIt == canonicalResourceMap.end())
248 return varIt->second;
251 spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
252 spirv::GlobalVariableOp varOp)
const {
253 auto descriptorIt = descriptorMap.find(varOp);
254 if (descriptorIt == descriptorMap.end())
256 return getCanonicalResource(descriptorIt->second);
261 auto it = elementTypeMap.find(varOp);
262 if (it == elementTypeMap.end())
267 void ResourceAliasAnalysis::recordIfUnifiable(
271 for (spirv::GlobalVariableOp resource : resources) {
276 auto type = cast<spirv::SPIRVType>(elementType);
277 if (!type.isScalarOrVector())
280 elementTypes.push_back(type);
288 resourceMap[descriptor].assign(resources.begin(), resources.end());
289 canonicalResourceMap[descriptor] = resources[*index];
291 descriptorMap[resource.value()] = descriptor;
292 elementTypeMap[resource.value()] = elementTypes[resource.index()];
300 template <
typename OpTy>
331 auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
332 auto srcVarOp = cast<spirv::GlobalVariableOp>(
333 SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()));
334 auto dstVarOp =
analysis.getCanonicalResource(srcVarOp);
346 auto addressOp = acOp.getBasePtr().getDefiningOp<spirv::AddressOfOp>();
350 auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
351 auto srcVarOp = cast<spirv::GlobalVariableOp>(
352 SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()));
353 auto dstVarOp =
analysis.getCanonicalResource(srcVarOp);
358 if (srcElemType == dstElemType ||
363 acOp, adaptor.getBasePtr(), adaptor.getIndices());
369 if (srcElemType.
isIntOrFloat() && isa<VectorType>(dstElemType)) {
376 assert(dstNumBytes >= srcNumBytes && dstNumBytes % srcNumBytes == 0);
378 auto indices = llvm::to_vector<4>(acOp.getIndices());
379 Value oldIndex = indices.back();
382 int ratio = dstNumBytes / srcNumBytes;
383 auto ratioValue = rewriter.
create<spirv::ConstantOp>(
387 rewriter.
create<spirv::SDivOp>(loc, indexType, oldIndex, ratioValue);
389 rewriter.
create<spirv::SModOp>(loc, indexType, oldIndex, ratioValue));
392 acOp, adaptor.getBasePtr(), indices);
397 (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) {
403 assert(srcNumBytes >= dstNumBytes && srcNumBytes % dstNumBytes == 0);
405 auto indices = llvm::to_vector<4>(acOp.getIndices());
406 Value oldIndex = indices.back();
409 int ratio = srcNumBytes / dstNumBytes;
410 auto ratioValue = rewriter.
create<spirv::ConstantOp>(
414 rewriter.
create<spirv::IMulOp>(loc, indexType, oldIndex, ratioValue);
417 acOp, adaptor.getBasePtr(), indices);
422 acOp,
"unsupported src/dst types for spirv.AccessChain");
432 auto srcPtrType = cast<spirv::PointerType>(loadOp.getPtr().getType());
433 auto srcElemType = cast<spirv::SPIRVType>(srcPtrType.getPointeeType());
434 auto dstPtrType = cast<spirv::PointerType>(adaptor.getPtr().getType());
435 auto dstElemType = cast<spirv::SPIRVType>(dstPtrType.getPointeeType());
438 auto newLoadOp = rewriter.
create<spirv::LoadOp>(loc, adaptor.getPtr());
439 if (srcElemType == dstElemType) {
440 rewriter.
replaceOp(loadOp, newLoadOp->getResults());
445 auto castOp = rewriter.
create<spirv::BitcastOp>(loc, srcElemType,
446 newLoadOp.getValue());
447 rewriter.
replaceOp(loadOp, castOp->getResults());
452 if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
453 (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) {
458 int srcNumBytes = *srcElemType.getSizeInBytes();
459 int dstNumBytes = *dstElemType.getSizeInBytes();
460 assert(srcNumBytes > dstNumBytes && srcNumBytes % dstNumBytes == 0);
461 int ratio = srcNumBytes / dstNumBytes;
466 components.reserve(ratio);
467 components.push_back(newLoadOp);
469 auto acOp = adaptor.getPtr().getDefiningOp<spirv::AccessChainOp>();
474 Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter);
475 auto indices = llvm::to_vector<4>(acOp.getIndices());
476 for (
int i = 1; i < ratio; ++i) {
478 indices.back() = rewriter.
create<spirv::IAddOp>(
479 loc, i32Type, indices.back(), oneValue);
480 auto componentAcOp = rewriter.
create<spirv::AccessChainOp>(
481 loc, acOp.getBasePtr(), indices);
484 components.push_back(
485 rewriter.
create<spirv::LoadOp>(loc, componentAcOp));
493 Type vectorType = srcElemType;
494 if (!isa<VectorType>(srcElemType))
499 if (
auto srcElemVecType = dyn_cast<VectorType>(srcElemType))
500 if (
auto dstElemVecType = dyn_cast<VectorType>(dstElemType)) {
501 if (srcElemVecType.getElementType() !=
502 dstElemVecType.getElementType()) {
504 dstNumBytes / (srcElemVecType.getElementTypeBitWidth() / 8);
508 Type castType = srcElemVecType.getElementType();
512 for (
Value &c : components)
513 c = rewriter.
create<spirv::BitcastOp>(loc, castType, c);
516 Value vectorValue = rewriter.
create<spirv::CompositeConstructOp>(
517 loc, vectorType, components);
519 if (!isa<VectorType>(srcElemType))
521 rewriter.
create<spirv::BitcastOp>(loc, srcElemType, vectorValue);
527 loadOp,
"unsupported src/dst types for spirv.Load");
538 cast<spirv::PointerType>(storeOp.getPtr().getType()).getPointeeType();
540 cast<spirv::PointerType>(adaptor.getPtr().getType()).getPointeeType();
541 if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
547 Value value = adaptor.getValue();
548 if (srcElemType != dstElemType)
549 value = rewriter.
create<spirv::BitcastOp>(loc, dstElemType, value);
551 value, storeOp->getAttrs());
561 class UnifyAliasedResourcePass final
562 :
public spirv::impl::SPIRVUnifyAliasedResourcePassBase<
563 UnifyAliasedResourcePass> {
566 : getTargetEnvFn(std::move(getTargetEnv)) {}
568 void runOnOperation()
override;
574 void UnifyAliasedResourcePass::runOnOperation() {
575 spirv::ModuleOp moduleOp = getOperation();
578 if (getTargetEnvFn) {
584 bool isVulkanOnAppleDevices =
585 clientAPI == spirv::ClientAPI::Vulkan &&
587 if (clientAPI != spirv::ClientAPI::WebGPU &&
588 clientAPI != spirv::ClientAPI::Metal && !isVulkanOnAppleDevices)
593 ResourceAliasAnalysis &analysis = getAnalysis<ResourceAliasAnalysis>();
596 target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp,
597 spirv::AccessChainOp, spirv::LoadOp,
599 [&analysis](
Operation *op) {
return !analysis.shouldUnify(op); });
600 target.addLegalDialect<spirv::SPIRVDialect>();
607 return signalPassFailure();
614 for (
const auto &dr : resourceMap) {
615 const auto &resources = dr.second;
616 if (resources.size() == 1)
617 resources.front()->removeAttr(
"aliased");
622 std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
624 return std::make_unique<UnifyAliasedResourcePass>(std::move(getTargetEnv));
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp)
Collects all aliased resources in the given SPIR-V moduleOp.
static Type getRuntimeArrayElementType(Type type)
Returns the element type if the given type is a runtime array resource: !spirv.ptr<!...
static bool areSameBitwidthScalarType(Type a, Type b)
static std::optional< int > deduceCanonicalResource(ArrayRef< spirv::SPIRVType > types)
Given a list of resource element types, returns the index of the canonical resource that all resource...
std::pair< uint32_t, uint32_t > Descriptor
ConvertAliasResource(const ResourceAliasAnalysis &analysis, MLIRContext *context, PatternBenefit benefit=1)
const ResourceAliasAnalysis & analysis
IntegerAttr getIntegerAttr(Type type, int64_t value)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
An attribute that specifies the target version, allowed extensions and capabilities,...
Vendor getVendorID() const
Returns the vendor ID.
ClientAPI getClientAPI() const
Returns the client API.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
std::unique_ptr< OperationPass< spirv::ModuleOp > > createUnifyAliasedResourcePass(GetTargetEnvFn getTargetEnv=nullptr)
std::function< spirv::TargetEnvAttr(spirv::ModuleOp)> GetTargetEnvFn
Creates an operation pass that unifies access of multiple aliased resources into access of one single...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
LogicalResult matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override