25 #include "llvm/ADT/DenseMap.h" 26 #include "llvm/ADT/STLExtras.h" 27 #include "llvm/Support/Debug.h" 31 #define DEBUG_TYPE "spirv-unify-aliased-resource" 46 moduleOp->walk([&aliasedResources](spirv::GlobalVariableOp varOp) {
47 if (varOp->getAttrOfType<UnitAttr>(
"aliased")) {
48 Optional<uint32_t> set = varOp.descriptor_set();
49 Optional<uint32_t> binding = varOp.binding();
51 aliasedResources[{*set, *binding}].push_back(varOp);
54 return aliasedResources;
65 if (!structType || structType.getNumElements() != 1)
81 scalarNumBits.reserve(types.size());
82 totalNumBits.reserve(types.size());
83 bool hasVector =
false;
86 assert(type.isScalarOrVector());
87 if (
auto vectorType = type.dyn_cast<VectorType>()) {
95 scalarNumBits.push_back(
96 vectorType.getElementType().getIntOrFloatBitWidth());
97 totalNumBits.push_back(*numBytes * 8);
100 scalarNumBits.push_back(type.getIntOrFloatBitWidth());
101 totalNumBits.push_back(scalarNumBits.back());
108 if (!llvm::is_splat(scalarNumBits))
113 auto *maxVal = std::max_element(totalNumBits.begin(), totalNumBits.end());
116 if (llvm::any_of(totalNumBits,
117 [maxVal](int64_t bits) {
return *maxVal % bits != 0; }))
120 return std::distance(totalNumBits.begin(), maxVal);
125 auto *minVal = std::min_element(scalarNumBits.begin(), scalarNumBits.end());
126 if (llvm::any_of(scalarNumBits,
127 [minVal](int64_t bit) {
return bit % *minVal != 0; }))
129 return std::distance(scalarNumBits.begin(), minVal);
150 class ResourceAliasAnalysis {
154 explicit ResourceAliasAnalysis(
Operation *);
164 spirv::GlobalVariableOp
165 getCanonicalResource(
const Descriptor &descriptor)
const;
166 spirv::GlobalVariableOp
167 getCanonicalResource(spirv::GlobalVariableOp varOp)
const;
175 void recordIfUnifiable(
const Descriptor &descriptor,
192 ResourceAliasAnalysis::ResourceAliasAnalysis(
Operation *root) {
200 for (
const auto &descriptorResource : aliasedResources) {
201 recordIfUnifiable(descriptorResource.first, descriptorResource.second);
205 bool ResourceAliasAnalysis::shouldUnify(
Operation *op)
const {
206 if (
auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
207 auto canonicalOp = getCanonicalResource(varOp);
208 return canonicalOp && varOp != canonicalOp;
210 if (
auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) {
211 auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
213 return shouldUnify(varOp);
216 if (
auto acOp = dyn_cast<spirv::AccessChainOp>(op))
217 return shouldUnify(acOp.base_ptr().getDefiningOp());
218 if (
auto loadOp = dyn_cast<spirv::LoadOp>(op))
219 return shouldUnify(loadOp.ptr().getDefiningOp());
220 if (
auto storeOp = dyn_cast<spirv::StoreOp>(op))
221 return shouldUnify(storeOp.ptr().getDefiningOp());
226 spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
228 auto varIt = canonicalResourceMap.find(descriptor);
229 if (varIt == canonicalResourceMap.end())
231 return varIt->second;
234 spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
235 spirv::GlobalVariableOp varOp)
const {
236 auto descriptorIt = descriptorMap.find(varOp);
237 if (descriptorIt == descriptorMap.end())
239 return getCanonicalResource(descriptorIt->second);
244 auto it = elementTypeMap.find(varOp);
245 if (it == elementTypeMap.end())
250 void ResourceAliasAnalysis::recordIfUnifiable(
254 for (spirv::GlobalVariableOp resource : resources) {
260 if (!type.isScalarOrVector())
263 elementTypes.push_back(type);
271 resourceMap[descriptor].assign(resources.begin(), resources.end());
272 canonicalResourceMap[descriptor] = resources[*index];
274 descriptorMap[resource.value()] = descriptor;
275 elementTypeMap[resource.value()] = elementTypes[resource.index()];
283 template <
typename OpTy>
314 auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
315 auto srcVarOp = cast<spirv::GlobalVariableOp>(
317 auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
329 auto addressOp = acOp.base_ptr().getDefiningOp<spirv::AddressOfOp>();
333 auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
334 auto srcVarOp = cast<spirv::GlobalVariableOp>(
336 auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
341 if (srcElemType == dstElemType ||
346 acOp, adaptor.base_ptr(), adaptor.indices());
360 assert(dstNumBits > srcNumBits && dstNumBits % srcNumBits == 0);
361 int ratio = dstNumBits / srcNumBits;
362 auto ratioValue = rewriter.
create<spirv::ConstantOp>(
365 auto indices = llvm::to_vector<4>(acOp.indices());
366 Value oldIndex = indices.back();
368 rewriter.
create<spirv::SDivOp>(loc, i32Type, oldIndex, ratioValue);
370 rewriter.
create<spirv::SModOp>(loc, i32Type, oldIndex, ratioValue));
373 acOp, adaptor.base_ptr(), indices);
383 assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0);
384 int ratio = srcNumBits / dstNumBits;
385 auto ratioValue = rewriter.
create<spirv::ConstantOp>(
388 auto indices = llvm::to_vector<4>(acOp.indices());
389 Value oldIndex = indices.back();
391 rewriter.
create<spirv::IMulOp>(loc, i32Type, oldIndex, ratioValue);
394 acOp, adaptor.base_ptr(), indices);
412 if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
416 auto newLoadOp = rewriter.
create<spirv::LoadOp>(loc, adaptor.ptr());
417 if (srcElemType == dstElemType) {
418 rewriter.
replaceOp(loadOp, newLoadOp->getResults());
423 auto castOp = rewriter.
create<spirv::BitcastOp>(loc, srcElemType,
425 rewriter.
replaceOp(loadOp, castOp->getResults());
434 int srcNumBits = srcElemType.getIntOrFloatBitWidth();
435 int dstNumBits = dstElemType.getIntOrFloatBitWidth();
436 assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0);
437 int ratio = srcNumBits / dstNumBits;
442 components.reserve(ratio);
443 components.push_back(newLoadOp);
445 auto acOp = adaptor.ptr().getDefiningOp<spirv::AccessChainOp>();
450 Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter);
451 auto indices = llvm::to_vector<4>(acOp.indices());
452 for (
int i = 1; i < ratio; ++i) {
454 indices.back() = rewriter.
create<spirv::IAddOp>(loc, i32Type,
455 indices.back(), oneValue);
457 rewriter.
create<spirv::AccessChainOp>(loc, acOp.base_ptr(), indices);
460 components.push_back(rewriter.
create<spirv::LoadOp>(loc, componentAcOp));
466 auto vectorType = VectorType::get({ratio}, dstElemType);
467 Value vectorValue = rewriter.
create<spirv::CompositeConstructOp>(
485 if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
492 if (srcElemType != dstElemType)
493 value = rewriter.
create<spirv::BitcastOp>(loc, dstElemType,
value);
495 storeOp->getAttrs());
505 class UnifyAliasedResourcePass final
506 :
public SPIRVUnifyAliasedResourcePassBase<UnifyAliasedResourcePass> {
508 void runOnOperation()
override;
512 void UnifyAliasedResourcePass::runOnOperation() {
513 spirv::ModuleOp moduleOp = getOperation();
517 ResourceAliasAnalysis &analysis = getAnalysis<ResourceAliasAnalysis>();
520 target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp,
521 spirv::AccessChainOp, spirv::LoadOp,
523 [&analysis](
Operation *op) {
return !analysis.shouldUnify(op); });
524 target.addLegalDialect<spirv::SPIRVDialect>();
531 return signalPassFailure();
538 for (
const auto &dr : resourceMap) {
539 const auto &resources = dr.second;
540 if (resources.size() == 1)
541 resources.front()->removeAttr(
"aliased");
545 std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
547 return std::make_unique<UnifyAliasedResourcePass>();
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
static bool areSameBitwidthScalarType(Type a, Type b)
Operation is a basic unit of execution within MLIR.
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
Type getPointeeType() const
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
std::unique_ptr< OperationPass< spirv::ModuleOp > > createUnifyAliasedResourcePass()
Creates an operation pass that unifies access of multiple aliased resources into access of one single...
LogicalResult matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
LogicalResult matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
Type getElementType() const
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
LogicalResult matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
IntegerAttr getI32IntegerAttr(int32_t value)
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an efficient way to signal success or failure.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
ConvertAliasResource(const ResourceAliasAnalysis &analysis, MLIRContext *context, PatternBenefit benefit=1)
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
const ResourceAliasAnalysis & analysis
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
LogicalResult matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
std::pair< uint32_t, uint32_t > Descriptor
Type getElementType(unsigned) const
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
MLIRContext is the top-level object for a collection of MLIR operations.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class implements a pattern rewriter for use with ConversionPatterns.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp)
Collects all aliased resources in the given SPIR-V moduleOp.
static Optional< int > deduceCanonicalResource(ArrayRef< spirv::SPIRVType > types)
Given a list of resource element types, returns the index of the canonical resource that all resource...
Optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
LogicalResult matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static Type getRuntimeArrayElementType(Type type)
Returns the element type if the given type is a runtime array resource: !spv.ptr<!spv.struct<!spv.rtarray<...>>>.