// NOTE(review): the right-hand sides of these pattern aliases are not present
// in this listing. By name they presumably lower index shl / signed shr /
// unsigned shr to the corresponding SPIR-V shift ops — confirm.
42using ConvertIndexShl =
44using ConvertIndexShrS =
46using ConvertIndexShrU =
/// Lowers a `BoolConstantOp` to a `spirv::ConstantOp` whose result type is the
/// original op's result type (`op.getType()`).
/// NOTE(review): the attribute argument passed to `replaceOpWithNewOp` and the
/// rest of the body are not in this listing; presumably the op's boolean value
/// is forwarded unchanged — confirm.
64struct ConvertIndexConstantBoolOpPattern final
65 : OpConversionPattern<BoolConstantOp> {
69 matchAndRewrite(BoolConstantOp op, BoolConstantOpAdaptor adaptor,
70 ConversionPatternRewriter &rewriter)
const override {
71 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, op.getType(),
/// Lowers an index `ConstantOp` to a `spirv::ConstantOp` typed with the SPIR-V
/// type converter's index type.
83struct ConvertIndexConstantOpPattern final : OpConversionPattern<ConstantOp> {
87 matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
88 ConversionPatternRewriter &rewriter)
const override {
89 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
90 Type indexType = typeConverter->getIndexType();
// Narrow the constant to the converter's index bitwidth; bits above that
// width are discarded. NOTE(review): `APInt::trunc` requires the target width
// not to exceed the value's current width — presumably the stored constant is
// at least as wide as any index bitwidth; confirm.
92 APInt value = op.getValue().trunc(typeConverter->getIndexTypeBitwidth());
93 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
94 op, indexType, IntegerAttr::get(indexType, value));
/// Lowers signed ceiling division `ceildivs(n, m)` to SPIR-V integer ops.
/// With `x = m > 0 ? -1 : 1`, the emitted code computes
///   `(n > 0) == (m > 0) && n != 0 ? (n + x) / m + 1 : -(-n / m)`
/// where both divisions are `spirv.SDiv`.
106struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> {
110 matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
111 ConversionPatternRewriter &rewriter)
const override {
112 Location loc = op.getLoc();
113 Value n = adaptor.getLhs();
115 Value m = adaptor.getRhs();
// NOTE(review): `nType` is declared on a line not in this listing
// (presumably the type of `n`); all constants below use it.
118 Value zero = spirv::ConstantOp::create(rewriter, loc, nType,
119 IntegerAttr::get(nType, 0));
120 Value posOne = spirv::ConstantOp::create(rewriter, loc, nType,
121 IntegerAttr::get(nType, 1));
122 Value negOne = spirv::ConstantOp::create(rewriter, loc, nType,
123 IntegerAttr::get(nType, -1));
// x = m > 0 ? -1 : 1
126 Value mPos = spirv::SGreaterThanOp::create(rewriter, loc, m, zero);
127 Value x = spirv::SelectOp::create(rewriter, loc, mPos, negOne, posOne);
// Result when the quotient is positive: (n + x) / m + 1.
130 Value nPlusX = spirv::IAddOp::create(rewriter, loc, n, x);
131 Value nPlusXDivM = spirv::SDivOp::create(rewriter, loc, nPlusX, m);
132 Value posRes = spirv::IAddOp::create(rewriter, loc, nPlusXDivM, posOne);
// Result otherwise: -(-n / m).
135 Value negN = spirv::ISubOp::create(rewriter, loc, zero, n);
136 Value negNDivM = spirv::SDivOp::create(rewriter, loc, negN, m);
137 Value negRes = spirv::ISubOp::create(rewriter, loc, zero, negNDivM);
// Use the positive form only when n and m have the same sign and n != 0;
// spirv.Select takes its first value operand when the condition is true.
141 Value nPos = spirv::SGreaterThanOp::create(rewriter, loc, n, zero);
142 Value sameSign = spirv::LogicalEqualOp::create(rewriter, loc, nPos, mPos);
143 Value nNonZero = spirv::INotEqualOp::create(rewriter, loc, n, zero);
144 Value cmp = spirv::LogicalAndOp::create(rewriter, loc, sameSign, nNonZero);
145 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
/// Lowers unsigned ceiling division `ceildivu(n, m)` as
///   `n == 0 ? 0 : (n - 1) / m + 1`
/// using `spirv.UDiv`. When n is 0 the (wrapped) `n - 1` path is still
/// computed but discarded by the final select.
156struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> {
160 matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
161 ConversionPatternRewriter &rewriter)
const override {
162 Location loc = op.getLoc();
163 Value n = adaptor.getLhs();
165 Value m = adaptor.getRhs();
// NOTE(review): `nType` is declared on a line not in this listing
// (presumably the type of `n`).
168 Value zero = spirv::ConstantOp::create(rewriter, loc, nType,
169 IntegerAttr::get(nType, 0));
170 Value one = spirv::ConstantOp::create(rewriter, loc, nType,
171 IntegerAttr::get(nType, 1));
// (n - 1) / m + 1
174 Value minusOne = spirv::ISubOp::create(rewriter, loc, n, one);
175 Value quotient = spirv::UDivOp::create(rewriter, loc, minusOne, m);
176 Value plusOne = spirv::IAddOp::create(rewriter, loc, quotient, one);
// n == 0 ? 0 : (n - 1) / m + 1
179 Value cmp = spirv::IEqualOp::create(rewriter, loc, n, zero);
180 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, zero, plusOne);
192struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> {
196 matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
197 ConversionPatternRewriter &rewriter)
const override {
198 Location loc = op.getLoc();
199 Value n = adaptor.getLhs();
201 Value m = adaptor.getRhs();
204 Value zero = spirv::ConstantOp::create(rewriter, loc, nType,
205 IntegerAttr::get(nType, 0));
206 Value posOne = spirv::ConstantOp::create(rewriter, loc, nType,
207 IntegerAttr::get(nType, 1));
208 Value negOne = spirv::ConstantOp::create(rewriter, loc, nType,
209 IntegerAttr::get(nType, -1));
212 Value mNeg = spirv::SLessThanOp::create(rewriter, loc, m, zero);
213 Value x = spirv::SelectOp::create(rewriter, loc, mNeg, posOne, negOne);
216 Value xMinusN = spirv::ISubOp::create(rewriter, loc, x, n);
217 Value xMinusNDivM = spirv::SDivOp::create(rewriter, loc, xMinusN, m);
218 Value negRes = spirv::ISubOp::create(rewriter, loc, negOne, xMinusNDivM);
221 Value posRes = spirv::SDivOp::create(rewriter, loc, n, m);
225 Value nNeg = spirv::SLessThanOp::create(rewriter, loc, n, zero);
227 spirv::LogicalNotEqualOp::create(rewriter, loc, nNeg, mNeg);
228 Value nNonZero = spirv::INotEqualOp::create(rewriter, loc, n, zero);
230 Value cmp = spirv::LogicalAndOp::create(rewriter, loc, diffSign, nNonZero);
231 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
/// Lowers an index cast op (`CastOp`) to the SPIR-V integer conversion op
/// `ConvertOp`. If the source and destination types are identical, the
/// converted input is forwarded unchanged; otherwise a `ConvertOp` producing
/// `dstType` is created from the adaptor's operands.
/// NOTE(review): the bodies of the two `isa<IndexType>` branches are not in
/// this listing; presumably they substitute `indexType` for whichever side is
/// `index` — confirm.
244template <
typename CastOp,
typename ConvertOp>
245struct ConvertIndexCast final : OpConversionPattern<CastOp> {
246 using OpConversionPattern<CastOp>::OpConversionPattern;
249 matchAndRewrite(CastOp op,
typename CastOp::Adaptor adaptor,
250 ConversionPatternRewriter &rewriter)
const override {
251 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
252 Type indexType = typeConverter->getIndexType();
254 Type srcType = adaptor.getInput().getType();
255 Type dstType = op.getType();
256 if (isa<IndexType>(srcType)) {
259 if (isa<IndexType>(dstType)) {
// Identical types: the cast is a no-op, reuse the input directly.
263 if (srcType == dstType) {
264 rewriter.replaceOp(op, adaptor.getInput());
266 rewriter.template replaceOpWithNewOp<ConvertOp>(op, dstType,
267 adaptor.getOperands());
// Signed casts lower to spirv.SConvert, unsigned casts to spirv.UConvert.
273using ConvertIndexCastS = ConvertIndexCast<CastSOp, spirv::SConvertOp>;
274using ConvertIndexCastU = ConvertIndexCast<CastUOp, spirv::UConvertOp>;
/// Replaces the index `CmpOp` with a new SPIR-V comparison `ICmpOp` over the
/// converted lhs and rhs operands.
/// NOTE(review): the return statement is not in this listing; presumably it
/// returns success() — confirm.
281template <
typename ICmpOp>
282static LogicalResult rewriteCmpOp(CmpOp op, CmpOpAdaptor adaptor,
283 ConversionPatternRewriter &rewriter) {
284 rewriter.replaceOpWithNewOp<ICmpOp>(op, adaptor.getLhs(), adaptor.getRhs());
/// Lowers `CmpOp` by mapping its `IndexCmpPredicate` to the matching SPIR-V
/// integer comparison: EQ/NE to IEqual/INotEqual, the signed predicates
/// (SGE/SGT/SLE/SLT) to the S* comparisons, and the unsigned predicates
/// (UGE/UGT/ULE/ULT) to the U* comparisons.
288struct ConvertIndexCmpPattern final : OpConversionPattern<CmpOp> {
292 matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
293 ConversionPatternRewriter &rewriter)
const override {
295 switch (op.getPred()) {
296 case IndexCmpPredicate::EQ:
297 return rewriteCmpOp<spirv::IEqualOp>(op, adaptor, rewriter);
298 case IndexCmpPredicate::NE:
299 return rewriteCmpOp<spirv::INotEqualOp>(op, adaptor, rewriter);
300 case IndexCmpPredicate::SGE:
301 return rewriteCmpOp<spirv::SGreaterThanEqualOp>(op, adaptor, rewriter);
302 case IndexCmpPredicate::SGT:
303 return rewriteCmpOp<spirv::SGreaterThanOp>(op, adaptor, rewriter);
304 case IndexCmpPredicate::SLE:
305 return rewriteCmpOp<spirv::SLessThanEqualOp>(op, adaptor, rewriter);
306 case IndexCmpPredicate::SLT:
307 return rewriteCmpOp<spirv::SLessThanOp>(op, adaptor, rewriter);
308 case IndexCmpPredicate::UGE:
309 return rewriteCmpOp<spirv::UGreaterThanEqualOp>(op, adaptor, rewriter);
310 case IndexCmpPredicate::UGT:
311 return rewriteCmpOp<spirv::UGreaterThanOp>(op, adaptor, rewriter);
312 case IndexCmpPredicate::ULE:
313 return rewriteCmpOp<spirv::ULessThanEqualOp>(op, adaptor, rewriter);
314 case IndexCmpPredicate::ULT:
315 return rewriteCmpOp<spirv::ULessThanOp>(op, adaptor, rewriter);
// Every predicate is handled above; reaching this point is a programming error.
317 llvm_unreachable(
"Unknown predicate in ConvertIndexCmpPattern");
/// Lowers `SizeOfOp` to a SPIR-V constant equal to the type converter's index
/// bitwidth (`getIndexTypeBitwidth()`), typed with the converter's index type.
326struct ConvertIndexSizeOf final : OpConversionPattern<SizeOfOp> {
330 matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
331 ConversionPatternRewriter &rewriter)
const override {
332 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
333 Type indexType = typeConverter->getIndexType();
334 unsigned bitwidth = typeConverter->getIndexTypeBitwidth();
335 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
336 op, indexType, IntegerAttr::get(indexType, bitwidth));
// Fragment of the pattern-population code: these pattern types are listed for
// registration. NOTE(review): the enclosing call is not in this listing.
363 ConvertIndexConstantBoolOpPattern,
364 ConvertIndexConstantOpPattern,
365 ConvertIndexCeilDivSPattern,
366 ConvertIndexCeilDivUPattern,
367 ConvertIndexFloorDivSPattern,
370 ConvertIndexCmpPattern,
// NOTE(review): the condition choosing between the *GL and *CL min/max
// variants is not in this listing; by name, GL presumably targets the GLSL
// extended instruction set and CL the OpenCL one — confirm.
376 patterns.
add<ConvertIndexMaxSGL, ConvertIndexMaxUGL, ConvertIndexMinSGL,
377 ConvertIndexMinUGL>(typeConverter, patterns.
getContext());
380 patterns.
add<ConvertIndexMaxSCL, ConvertIndexMaxUCL, ConvertIndexMinSCL,
381 ConvertIndexMinUCL>(typeConverter, patterns.
getContext());
389#define GEN_PASS_DEF_CONVERTINDEXTOSPIRVPASS
390#include "mlir/Conversion/Passes.h.inc"
/// Pass that converts index-dialect ops on the current operation to SPIR-V
/// using a partial dialect conversion.
398struct ConvertIndexToSPIRVPass
402 void runOnOperation()
override {
403 Operation *op = getOperation();
405 std::unique_ptr<SPIRVConversionTarget>
target =
// The index width used by the type converter follows the pass's
// `use64bitIndex` setting.
408 SPIRVConversionOptions
options;
409 options.use64bitIndex = this->use64bitIndex;
410 SPIRVTypeConverter typeConverter(targetAttr,
options);
// UnrealizedConversionCastOp is explicitly legal, so such casts may remain in
// the converted IR.
414 target->addLegalOp<UnrealizedConversionCastOp>();
// Index-dialect ops are illegal: the conversion fails if any remain unconverted.
417 target->addIllegalDialect<index::IndexDialect>();
// NOTE(review): the handling of a failed conversion is not in this listing;
// presumably it signals pass failure — confirm.
422 if (
failed(applyPartialConversion(op, *
target, std::move(patterns))))
static llvm::ManagedStatic< PassManagerOptions > options
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Type getType() const
Return the type of this value.
void populateIndexToSPIRVPatterns(const SPIRVTypeConverter &converter, RewritePatternSet &patterns)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.