10 #include "../SPIRVCommon/Pattern.h"
19 using namespace index;
39 using ConvertIndexShl =
41 using ConvertIndexShrS =
43 using ConvertIndexShrU =
61 struct ConvertIndexConstantBoolOpPattern final
66 matchAndRewrite(BoolConstantOp op, BoolConstantOpAdaptor adaptor,
84 matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
86 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
87 Type indexType = typeConverter->getIndexType();
89 APInt value = op.getValue().trunc(typeConverter->getIndexTypeBitwidth());
107 matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
110 Value n = adaptor.getLhs();
112 Value m = adaptor.getRhs();
123 Value mPos = rewriter.
create<spirv::SGreaterThanOp>(loc, m, zero);
124 Value x = rewriter.
create<spirv::SelectOp>(loc, mPos, negOne, posOne);
127 Value nPlusX = rewriter.
create<spirv::IAddOp>(loc, n, x);
128 Value nPlusXDivM = rewriter.
create<spirv::SDivOp>(loc, nPlusX, m);
129 Value posRes = rewriter.
create<spirv::IAddOp>(loc, nPlusXDivM, posOne);
132 Value negN = rewriter.
create<spirv::ISubOp>(loc, zero, n);
133 Value negNDivM = rewriter.
create<spirv::SDivOp>(loc, negN, m);
134 Value negRes = rewriter.
create<spirv::ISubOp>(loc, zero, negNDivM);
138 Value nPos = rewriter.
create<spirv::SGreaterThanOp>(loc, n, zero);
139 Value sameSign = rewriter.
create<spirv::LogicalEqualOp>(loc, nPos, mPos);
140 Value nNonZero = rewriter.
create<spirv::INotEqualOp>(loc, n, zero);
141 Value cmp = rewriter.
create<spirv::LogicalAndOp>(loc, sameSign, nNonZero);
157 matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
160 Value n = adaptor.getLhs();
162 Value m = adaptor.getRhs();
167 Value one = rewriter.
create<spirv::ConstantOp>(loc, n_type,
171 Value minusOne = rewriter.
create<spirv::ISubOp>(loc, n, one);
172 Value quotient = rewriter.
create<spirv::UDivOp>(loc, minusOne, m);
173 Value plusOne = rewriter.
create<spirv::IAddOp>(loc, quotient, one);
176 Value cmp = rewriter.
create<spirv::IEqualOp>(loc, n, zero);
193 matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
196 Value n = adaptor.getLhs();
198 Value m = adaptor.getRhs();
209 Value mNeg = rewriter.
create<spirv::SLessThanOp>(loc, m, zero);
210 Value x = rewriter.
create<spirv::SelectOp>(loc, mNeg, posOne, negOne);
213 Value xMinusN = rewriter.
create<spirv::ISubOp>(loc, x, n);
214 Value xMinusNDivM = rewriter.
create<spirv::SDivOp>(loc, xMinusN, m);
215 Value negRes = rewriter.
create<spirv::ISubOp>(loc, negOne, xMinusNDivM);
218 Value posRes = rewriter.
create<spirv::SDivOp>(loc, n, m);
222 Value nNeg = rewriter.
create<spirv::SLessThanOp>(loc, n, zero);
223 Value diffSign = rewriter.
create<spirv::LogicalNotEqualOp>(loc, nNeg, mNeg);
224 Value nNonZero = rewriter.
create<spirv::INotEqualOp>(loc, n, zero);
226 Value cmp = rewriter.
create<spirv::LogicalAndOp>(loc, diffSign, nNonZero);
240 template <
typename CastOp,
typename ConvertOp>
245 matchAndRewrite(CastOp op,
typename CastOp::Adaptor adaptor,
247 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
248 Type indexType = typeConverter->getIndexType();
250 Type srcType = adaptor.getInput().getType();
251 Type dstType = op.getType();
252 if (isa<IndexType>(srcType)) {
255 if (isa<IndexType>(dstType)) {
259 if (srcType == dstType) {
260 rewriter.
replaceOp(op, adaptor.getInput());
262 rewriter.template replaceOpWithNewOp<ConvertOp>(op, dstType,
263 adaptor.getOperands());
269 using ConvertIndexCastS = ConvertIndexCast<CastSOp, spirv::SConvertOp>;
270 using ConvertIndexCastU = ConvertIndexCast<CastUOp, spirv::UConvertOp>;
277 template <
typename ICmpOp>
278 static LogicalResult rewriteCmpOp(CmpOp op, CmpOpAdaptor adaptor,
288 matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
291 switch (op.getPred()) {
292 case IndexCmpPredicate::EQ:
293 return rewriteCmpOp<spirv::IEqualOp>(op, adaptor, rewriter);
294 case IndexCmpPredicate::NE:
295 return rewriteCmpOp<spirv::INotEqualOp>(op, adaptor, rewriter);
296 case IndexCmpPredicate::SGE:
297 return rewriteCmpOp<spirv::SGreaterThanEqualOp>(op, adaptor, rewriter);
298 case IndexCmpPredicate::SGT:
299 return rewriteCmpOp<spirv::SGreaterThanOp>(op, adaptor, rewriter);
300 case IndexCmpPredicate::SLE:
301 return rewriteCmpOp<spirv::SLessThanEqualOp>(op, adaptor, rewriter);
302 case IndexCmpPredicate::SLT:
303 return rewriteCmpOp<spirv::SLessThanOp>(op, adaptor, rewriter);
304 case IndexCmpPredicate::UGE:
305 return rewriteCmpOp<spirv::UGreaterThanEqualOp>(op, adaptor, rewriter);
306 case IndexCmpPredicate::UGT:
307 return rewriteCmpOp<spirv::UGreaterThanOp>(op, adaptor, rewriter);
308 case IndexCmpPredicate::ULE:
309 return rewriteCmpOp<spirv::ULessThanEqualOp>(op, adaptor, rewriter);
310 case IndexCmpPredicate::ULT:
311 return rewriteCmpOp<spirv::ULessThanOp>(op, adaptor, rewriter);
313 llvm_unreachable(
"Unknown predicate in ConvertIndexCmpPattern");
326 matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
328 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
329 Type indexType = typeConverter->getIndexType();
330 unsigned bitwidth = typeConverter->getIndexTypeBitwidth();
363 ConvertIndexConstantBoolOpPattern,
364 ConvertIndexConstantOpPattern,
365 ConvertIndexCeilDivSPattern,
366 ConvertIndexCeilDivUPattern,
367 ConvertIndexFloorDivSPattern,
370 ConvertIndexCmpPattern,
372 >(typeConverter,
patterns.getContext());
380 #define GEN_PASS_DEF_CONVERTINDEXTOSPIRVPASS
381 #include "mlir/Conversion/Passes.h.inc"
389 struct ConvertIndexToSPIRVPass
390 :
public impl::ConvertIndexToSPIRVPassBase<ConvertIndexToSPIRVPass> {
393 void runOnOperation()
override {
396 std::unique_ptr<SPIRVConversionTarget> target =
400 options.use64bitIndex = this->use64bitIndex;
405 target->addLegalOp<UnrealizedConversionCastOp>();
408 target->addLegalDialect<spirv::SPIRVDialect>();
410 target->addIllegalDialect<index::IndexDialect>();
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replacing `op`).
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
An attribute that specifies the target version, allowed extensions and capabilities,...
void populateIndexToSPIRVPatterns(const SPIRVTypeConverter &converter, RewritePatternSet &patterns)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or returns the default target environment as returned by getDefaultTargetEnv() if not provided.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.