25 #include "llvm/ADT/DenseMap.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/Support/Debug.h"
33 #define GEN_PASS_DEF_SPIRVUNIFYALIASEDRESOURCEPASS
34 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
38 #define DEBUG_TYPE "spirv-unify-aliased-resource"
53 moduleOp->walk([&aliasedResources](spirv::GlobalVariableOp varOp) {
54 if (varOp->getAttrOfType<UnitAttr>(
"aliased")) {
55 std::optional<uint32_t> set = varOp.getDescriptorSet();
56 std::optional<uint32_t> binding = varOp.getBinding();
58 aliasedResources[{*set, *binding}].push_back(varOp);
61 return aliasedResources;
68 auto ptrType = dyn_cast<spirv::PointerType>(type);
72 auto structType = dyn_cast<spirv::StructType>(ptrType.getPointeeType());
73 if (!structType || structType.getNumElements() != 1)
77 dyn_cast<spirv::RuntimeArrayType>(structType.getElementType(0));
81 return rtArrayType.getElementType();
87 static std::optional<int>
93 scalarNumBits.reserve(types.size());
94 vectorNumBits.reserve(types.size());
95 vectorIndices.reserve(types.size());
100 if (
auto vectorType = dyn_cast<VectorType>(type)) {
101 if (vectorType.getNumElements() % 2 != 0)
109 scalarNumBits.push_back(
110 vectorType.getElementType().getIntOrFloatBitWidth());
111 vectorNumBits.push_back(*numBytes * 8);
112 vectorIndices.push_back(indexedTypes.index());
118 if (!vectorNumBits.empty()) {
122 auto *minVal = llvm::min_element(vectorNumBits);
125 if (llvm::any_of(vectorNumBits,
126 [&](
int bits) {
return bits % *minVal != 0; }))
131 int index = vectorIndices[std::distance(vectorNumBits.begin(), minVal)];
132 int baseNumBits = scalarNumBits[index];
133 if (llvm::any_of(scalarNumBits,
134 [&](
int bits) {
return bits % baseNumBits != 0; }))
142 auto *minVal = llvm::min_element(scalarNumBits);
143 if (llvm::any_of(scalarNumBits,
144 [minVal](int64_t bit) {
return bit % *minVal != 0; }))
146 return std::distance(scalarNumBits.begin(), minVal);
167 class ResourceAliasAnalysis {
171 explicit ResourceAliasAnalysis(
Operation *);
181 spirv::GlobalVariableOp
182 getCanonicalResource(
const Descriptor &descriptor)
const;
183 spirv::GlobalVariableOp
184 getCanonicalResource(spirv::GlobalVariableOp varOp)
const;
192 void recordIfUnifiable(
const Descriptor &descriptor,
209 ResourceAliasAnalysis::ResourceAliasAnalysis(
Operation *root) {
217 for (
const auto &descriptorResource : aliasedResources) {
218 recordIfUnifiable(descriptorResource.first, descriptorResource.second);
222 bool ResourceAliasAnalysis::shouldUnify(
Operation *op)
const {
226 if (
auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
227 auto canonicalOp = getCanonicalResource(varOp);
228 return canonicalOp && varOp != canonicalOp;
230 if (
auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) {
231 auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
233 SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable());
234 return shouldUnify(varOp);
237 if (
auto acOp = dyn_cast<spirv::AccessChainOp>(op))
238 return shouldUnify(acOp.getBasePtr().getDefiningOp());
239 if (
auto loadOp = dyn_cast<spirv::LoadOp>(op))
240 return shouldUnify(loadOp.getPtr().getDefiningOp());
241 if (
auto storeOp = dyn_cast<spirv::StoreOp>(op))
242 return shouldUnify(storeOp.getPtr().getDefiningOp());
247 spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
249 auto varIt = canonicalResourceMap.find(descriptor);
250 if (varIt == canonicalResourceMap.end())
252 return varIt->second;
255 spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
256 spirv::GlobalVariableOp varOp)
const {
257 auto descriptorIt = descriptorMap.find(varOp);
258 if (descriptorIt == descriptorMap.end())
260 return getCanonicalResource(descriptorIt->second);
265 auto it = elementTypeMap.find(varOp);
266 if (it == elementTypeMap.end())
271 void ResourceAliasAnalysis::recordIfUnifiable(
275 for (spirv::GlobalVariableOp resource : resources) {
280 auto type = cast<spirv::SPIRVType>(elementType);
281 if (!type.isScalarOrVector())
284 elementTypes.push_back(type);
292 resourceMap[descriptor].assign(resources.begin(), resources.end());
293 canonicalResourceMap[descriptor] = resources[*index];
295 descriptorMap[resource.value()] = descriptor;
296 elementTypeMap[resource.value()] = elementTypes[resource.index()];
304 template <
typename OpTy>
335 auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
336 auto srcVarOp = cast<spirv::GlobalVariableOp>(
337 SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()));
338 auto dstVarOp =
analysis.getCanonicalResource(srcVarOp);
350 auto addressOp = acOp.getBasePtr().getDefiningOp<spirv::AddressOfOp>();
354 auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
355 auto srcVarOp = cast<spirv::GlobalVariableOp>(
356 SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()));
357 auto dstVarOp =
analysis.getCanonicalResource(srcVarOp);
362 if (srcElemType == dstElemType ||
367 acOp, adaptor.getBasePtr(), adaptor.getIndices());
373 if (srcElemType.
isIntOrFloat() && isa<VectorType>(dstElemType)) {
380 assert(dstNumBytes >= srcNumBytes && dstNumBytes % srcNumBytes == 0);
382 auto indices = llvm::to_vector<4>(acOp.getIndices());
383 Value oldIndex = indices.back();
386 int ratio = dstNumBytes / srcNumBytes;
387 auto ratioValue = rewriter.
create<spirv::ConstantOp>(
391 rewriter.
create<spirv::SDivOp>(loc, indexType, oldIndex, ratioValue);
393 rewriter.
create<spirv::SModOp>(loc, indexType, oldIndex, ratioValue));
396 acOp, adaptor.getBasePtr(), indices);
401 (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) {
407 assert(srcNumBytes >= dstNumBytes && srcNumBytes % dstNumBytes == 0);
409 auto indices = llvm::to_vector<4>(acOp.getIndices());
410 Value oldIndex = indices.back();
413 int ratio = srcNumBytes / dstNumBytes;
414 auto ratioValue = rewriter.
create<spirv::ConstantOp>(
418 rewriter.
create<spirv::IMulOp>(loc, indexType, oldIndex, ratioValue);
421 acOp, adaptor.getBasePtr(), indices);
426 acOp,
"unsupported src/dst types for spirv.AccessChain");
436 auto srcPtrType = cast<spirv::PointerType>(loadOp.getPtr().getType());
437 auto srcElemType = cast<spirv::SPIRVType>(srcPtrType.getPointeeType());
438 auto dstPtrType = cast<spirv::PointerType>(adaptor.getPtr().getType());
439 auto dstElemType = cast<spirv::SPIRVType>(dstPtrType.getPointeeType());
442 auto newLoadOp = rewriter.
create<spirv::LoadOp>(loc, adaptor.getPtr());
443 if (srcElemType == dstElemType) {
444 rewriter.
replaceOp(loadOp, newLoadOp->getResults());
449 auto castOp = rewriter.
create<spirv::BitcastOp>(loc, srcElemType,
450 newLoadOp.getValue());
451 rewriter.
replaceOp(loadOp, castOp->getResults());
456 if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
457 (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) {
462 int srcNumBytes = *srcElemType.getSizeInBytes();
463 int dstNumBytes = *dstElemType.getSizeInBytes();
464 assert(srcNumBytes > dstNumBytes && srcNumBytes % dstNumBytes == 0);
465 int ratio = srcNumBytes / dstNumBytes;
470 components.reserve(ratio);
471 components.push_back(newLoadOp);
473 auto acOp = adaptor.getPtr().getDefiningOp<spirv::AccessChainOp>();
478 Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter);
479 auto indices = llvm::to_vector<4>(acOp.getIndices());
480 for (
int i = 1; i < ratio; ++i) {
482 indices.back() = rewriter.
create<spirv::IAddOp>(
483 loc, i32Type, indices.back(), oneValue);
484 auto componentAcOp = rewriter.
create<spirv::AccessChainOp>(
485 loc, acOp.getBasePtr(), indices);
488 components.push_back(
489 rewriter.
create<spirv::LoadOp>(loc, componentAcOp));
497 Type vectorType = srcElemType;
498 if (!isa<VectorType>(srcElemType))
503 if (
auto srcElemVecType = dyn_cast<VectorType>(srcElemType))
504 if (
auto dstElemVecType = dyn_cast<VectorType>(dstElemType)) {
505 if (srcElemVecType.getElementType() !=
506 dstElemVecType.getElementType()) {
508 dstNumBytes / (srcElemVecType.getElementTypeBitWidth() / 8);
512 Type castType = srcElemVecType.getElementType();
516 for (
Value &c : components)
517 c = rewriter.
create<spirv::BitcastOp>(loc, castType, c);
520 Value vectorValue = rewriter.
create<spirv::CompositeConstructOp>(
521 loc, vectorType, components);
523 if (!isa<VectorType>(srcElemType))
525 rewriter.
create<spirv::BitcastOp>(loc, srcElemType, vectorValue);
531 loadOp,
"unsupported src/dst types for spirv.Load");
542 cast<spirv::PointerType>(storeOp.getPtr().getType()).getPointeeType();
544 cast<spirv::PointerType>(adaptor.getPtr().getType()).getPointeeType();
545 if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
551 Value value = adaptor.getValue();
552 if (srcElemType != dstElemType)
553 value = rewriter.
create<spirv::BitcastOp>(loc, dstElemType, value);
555 value, storeOp->getAttrs());
565 class UnifyAliasedResourcePass final
566 :
public spirv::impl::SPIRVUnifyAliasedResourcePassBase<
567 UnifyAliasedResourcePass> {
570 : getTargetEnvFn(std::move(getTargetEnv)) {}
572 void runOnOperation()
override;
578 void UnifyAliasedResourcePass::runOnOperation() {
579 spirv::ModuleOp moduleOp = getOperation();
582 if (getTargetEnvFn) {
588 bool isVulkanOnAppleDevices =
589 clientAPI == spirv::ClientAPI::Vulkan &&
591 if (clientAPI != spirv::ClientAPI::WebGPU &&
592 clientAPI != spirv::ClientAPI::Metal && !isVulkanOnAppleDevices)
597 ResourceAliasAnalysis &analysis = getAnalysis<ResourceAliasAnalysis>();
600 target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp,
601 spirv::AccessChainOp, spirv::LoadOp,
603 [&analysis](
Operation *op) {
return !analysis.shouldUnify(op); });
604 target.addLegalDialect<spirv::SPIRVDialect>();
611 return signalPassFailure();
618 for (
const auto &dr : resourceMap) {
619 const auto &resources = dr.second;
620 if (resources.size() == 1)
621 resources.front()->removeAttr(
"aliased");
626 std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
628 return std::make_unique<UnifyAliasedResourcePass>(std::move(getTargetEnv));
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp)
Collects all aliased resources in the given SPIR-V moduleOp.
static Type getRuntimeArrayElementType(Type type)
Returns the element type if the given type is a runtime array resource: !spirv.ptr<!...
static bool areSameBitwidthScalarType(Type a, Type b)
static std::optional< int > deduceCanonicalResource(ArrayRef< spirv::SPIRVType > types)
Given a list of resource element types, returns the index of the canonical resource that all resource...
std::pair< uint32_t, uint32_t > Descriptor
ConvertAliasResource(const ResourceAliasAnalysis &analysis, MLIRContext *context, PatternBenefit benefit=1)
const ResourceAliasAnalysis & analysis
IntegerAttr getIntegerAttr(Type type, int64_t value)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
Operation is the basic unit of execution within MLIR.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
An attribute that specifies the target version, allowed extensions and capabilities,...
Vendor getVendorID() const
Returns the vendor ID.
ClientAPI getClientAPI() const
Returns the client API.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
RewritePatternSet & patterns
std::unique_ptr< OperationPass< spirv::ModuleOp > > createUnifyAliasedResourcePass(GetTargetEnvFn getTargetEnv=nullptr)
std::function< spirv::TargetEnvAttr(spirv::ModuleOp)> GetTargetEnvFn
Creates an operation pass that unifies access of multiple aliased resources into access of one single...
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
LogicalResult matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override