10 #include "../SPIRVCommon/Pattern.h"
19 using namespace index;
39 using ConvertIndexShl =
41 using ConvertIndexShrS =
43 using ConvertIndexShrU =
61 struct ConvertIndexConstantBoolOpPattern final
66 matchAndRewrite(BoolConstantOp op, BoolConstantOpAdaptor adaptor,
84 matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
86 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
87 Type indexType = typeConverter->getIndexType();
89 APInt value = op.getValue().trunc(typeConverter->getIndexTypeBitwidth());
107 matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
110 Value n = adaptor.getLhs();
112 Value m = adaptor.getRhs();
123 Value mPos = rewriter.
create<spirv::SGreaterThanOp>(loc, m, zero);
124 Value x = rewriter.
create<spirv::SelectOp>(loc, mPos, negOne, posOne);
127 Value nPlusX = rewriter.
create<spirv::IAddOp>(loc, n, x);
128 Value nPlusXDivM = rewriter.
create<spirv::SDivOp>(loc, nPlusX, m);
129 Value posRes = rewriter.
create<spirv::IAddOp>(loc, nPlusXDivM, posOne);
132 Value negN = rewriter.
create<spirv::ISubOp>(loc, zero, n);
133 Value negNDivM = rewriter.
create<spirv::SDivOp>(loc, negN, m);
134 Value negRes = rewriter.
create<spirv::ISubOp>(loc, zero, negNDivM);
138 Value nPos = rewriter.
create<spirv::SGreaterThanOp>(loc, n, zero);
139 Value sameSign = rewriter.
create<spirv::LogicalEqualOp>(loc, nPos, mPos);
140 Value nNonZero = rewriter.
create<spirv::INotEqualOp>(loc, n, zero);
141 Value cmp = rewriter.
create<spirv::LogicalAndOp>(loc, sameSign, nNonZero);
157 matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
160 Value n = adaptor.getLhs();
162 Value m = adaptor.getRhs();
167 Value one = rewriter.
create<spirv::ConstantOp>(loc, n_type,
171 Value minusOne = rewriter.
create<spirv::ISubOp>(loc, n, one);
172 Value quotient = rewriter.
create<spirv::UDivOp>(loc, minusOne, m);
173 Value plusOne = rewriter.
create<spirv::IAddOp>(loc, quotient, one);
176 Value cmp = rewriter.
create<spirv::IEqualOp>(loc, n, zero);
193 matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
196 Value n = adaptor.getLhs();
198 Value m = adaptor.getRhs();
209 Value mNeg = rewriter.
create<spirv::SLessThanOp>(loc, m, zero);
210 Value x = rewriter.
create<spirv::SelectOp>(loc, mNeg, posOne, negOne);
213 Value xMinusN = rewriter.
create<spirv::ISubOp>(loc, x, n);
214 Value xMinusNDivM = rewriter.
create<spirv::SDivOp>(loc, xMinusN, m);
215 Value negRes = rewriter.
create<spirv::ISubOp>(loc, negOne, xMinusNDivM);
218 Value posRes = rewriter.
create<spirv::SDivOp>(loc, n, m);
222 Value nNeg = rewriter.
create<spirv::SLessThanOp>(loc, n, zero);
223 Value diffSign = rewriter.
create<spirv::LogicalNotEqualOp>(loc, nNeg, mNeg);
224 Value nNonZero = rewriter.
create<spirv::INotEqualOp>(loc, n, zero);
226 Value cmp = rewriter.
create<spirv::LogicalAndOp>(loc, diffSign, nNonZero);
240 template <
typename CastOp,
typename ConvertOp>
245 matchAndRewrite(CastOp op,
typename CastOp::Adaptor adaptor,
247 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
248 Type indexType = typeConverter->getIndexType();
250 Type srcType = adaptor.getInput().getType();
251 Type dstType = op.getType();
252 if (isa<IndexType>(srcType)) {
255 if (isa<IndexType>(dstType)) {
259 if (srcType == dstType) {
260 rewriter.
replaceOp(op, adaptor.getInput());
262 rewriter.template replaceOpWithNewOp<ConvertOp>(op, dstType,
263 adaptor.getOperands());
269 using ConvertIndexCastS = ConvertIndexCast<CastSOp, spirv::SConvertOp>;
270 using ConvertIndexCastU = ConvertIndexCast<CastUOp, spirv::UConvertOp>;
277 template <
typename ICmpOp>
278 static LogicalResult rewriteCmpOp(CmpOp op, CmpOpAdaptor adaptor,
288 matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
291 switch (op.getPred()) {
292 case IndexCmpPredicate::EQ:
293 return rewriteCmpOp<spirv::IEqualOp>(op, adaptor, rewriter);
294 case IndexCmpPredicate::NE:
295 return rewriteCmpOp<spirv::INotEqualOp>(op, adaptor, rewriter);
296 case IndexCmpPredicate::SGE:
297 return rewriteCmpOp<spirv::SGreaterThanEqualOp>(op, adaptor, rewriter);
298 case IndexCmpPredicate::SGT:
299 return rewriteCmpOp<spirv::SGreaterThanOp>(op, adaptor, rewriter);
300 case IndexCmpPredicate::SLE:
301 return rewriteCmpOp<spirv::SLessThanEqualOp>(op, adaptor, rewriter);
302 case IndexCmpPredicate::SLT:
303 return rewriteCmpOp<spirv::SLessThanOp>(op, adaptor, rewriter);
304 case IndexCmpPredicate::UGE:
305 return rewriteCmpOp<spirv::UGreaterThanEqualOp>(op, adaptor, rewriter);
306 case IndexCmpPredicate::UGT:
307 return rewriteCmpOp<spirv::UGreaterThanOp>(op, adaptor, rewriter);
308 case IndexCmpPredicate::ULE:
309 return rewriteCmpOp<spirv::ULessThanEqualOp>(op, adaptor, rewriter);
310 case IndexCmpPredicate::ULT:
311 return rewriteCmpOp<spirv::ULessThanOp>(op, adaptor, rewriter);
325 matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
327 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
328 Type indexType = typeConverter->getIndexType();
329 unsigned bitwidth = typeConverter->getIndexTypeBitwidth();
362 ConvertIndexConstantBoolOpPattern,
363 ConvertIndexConstantOpPattern,
364 ConvertIndexCeilDivSPattern,
365 ConvertIndexCeilDivUPattern,
366 ConvertIndexFloorDivSPattern,
369 ConvertIndexCmpPattern,
379 #define GEN_PASS_DEF_CONVERTINDEXTOSPIRVPASS
380 #include "mlir/Conversion/Passes.h.inc"
388 struct ConvertIndexToSPIRVPass
389 :
public impl::ConvertIndexToSPIRVPassBase<ConvertIndexToSPIRVPass> {
392 void runOnOperation()
override {
395 std::unique_ptr<SPIRVConversionTarget> target =
399 options.use64bitIndex = this->use64bitIndex;
404 target->addLegalOp<UnrealizedConversionCastOp>();
407 target->addLegalDialect<spirv::SPIRVDialect>();
409 target->addIllegalDialect<index::IndexDialect>();
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
An attribute that specifies the target version, allowed extensions and capabilities,...
void populateIndexToSPIRVPatterns(SPIRVTypeConverter &converter, RewritePatternSet &patterns)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.