125 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
126 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
127 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
129 if (!convOp.hasPureTensorSemantics())
131 convOp,
"expected op to have pure tensor semantics");
133 if (!filterType.hasStaticShape())
135 convOp,
"expected a static shape for the filter");
137 if (!inputType.hasStaticShape())
139 "expected a static shape for the input");
144 "expected all ones for dilations");
147 Value input = convOp.getInputs()[0];
148 Value filter = convOp.getInputs()[1];
149 Value output = convOp.getOutputs()[0];
164 assert(isa<RankedTensorType>(filterType) &&
165 "expected filter type to be a ranked tensor");
166 auto tensorFilterType = cast<RankedTensorType>(filterType);
170 auto reshapedFilterType =
171 RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType(),
172 tensorFilterType.getEncoding());
173 Value reshapedFilter = tensor::CollapseShapeOp::create(
174 rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
177 RankedTensorType reshapedOutputType =
178 RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
179 Value reshapedOutput = tensor::CollapseShapeOp::create(
180 rewriter, loc, reshapedOutputType, output, outputReassocIndices);
183 Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
184 inputType.getElementType());
187 auto nloops = colTensorShape.size();
189 auto parallel = utils::IteratorType::parallel;
190 auto reduction = utils::IteratorType::reduction;
200 i2cToOperExprs.
fhIndex = kIndicesExprs[0];
201 i2cToOperExprs.
fwIndex = kIndicesExprs[1];
202 i2cToOperExprs.
icIndex = kIndicesExprs[2];
203 i2cToOperExprs.
ohIndex = mIndicesExprs[0];
204 i2cToOperExprs.
owIndex = mIndicesExprs[1];
208 i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<
int64_t>()),
218 auto img2ColTensor = linalg::GenericOp::create(
219 rewriter, loc, colTensor.
getType(),
220 input, colTensor, img2colIndexingMaps,
223 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
231 bindDims(context, bDim, mDim, nDim, kDim);
234 auto resultMap =
AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
236 parallel, reduction};
238 auto genericOp = linalg::GenericOp::create(
239 rewriter, loc, reshapedOutputType,
240 ValueRange{img2ColTensor.getResult(0), reshapedFilter},
247 linalg::YieldOp::create(nestedBuilder, nestedLoc,
add);
251 auto reshapedResult = tensor::ExpandShapeOp::create(
252 rewriter, loc, outputType,
result, outputReassocIndices);
256 return std::make_pair(img2ColTensor.getOperation(),
257 reshapedResult.getOperation());
262 linalg::DepthwiseConv2DNhwcHwcOp convOp) {
263 auto inputType = cast<RankedTensorType>(convOp.getInputs()[0].getType());
264 auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
265 auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());
267 if (!convOp.hasPureTensorSemantics())
269 convOp,
"expected op to have pure tensor semantics");
271 if (!filterType.hasStaticShape())
273 convOp,
"expected a static shape for the filter");
275 if (!inputType.hasStaticShape())
277 "expected a static shape for the input");
282 "expected all ones for dilations");
287 auto operandTensorType = cast<RankedTensorType>(operand.
getType());
299 Value outputTensor = tensor::EmptyOp::create(
300 rewriter, loc, targetShape, operandTensorType.getElementType());
303 nloops, utils::IteratorType::parallel);
310 auto transposedOp = linalg::GenericOp::create(
311 rewriter, loc, outputTensor.getType(),
312 operand, outputTensor, indexingMaps,
315 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
318 return transposedOp.getResult(0);
321 Value input = convOp.getInputs()[0];
322 Value filter = convOp.getInputs()[1];
323 Value output = convOp.getOutputs()[0];
326 Value inputT = transposeOperand(input, {0, 3, 1, 2});
327 Value filterT = transposeOperand(filter, {2, 0, 1});
329 cast<RankedTensorType>(filterT.getType()).getShape();
// Shape extents are 64-bit; keep them as int64_t so that large dimensions and
// the products formed from them (n * c, oh * ow, fh * fw) are not truncated
// or overflowed in 32-bit arithmetic.
332 int64_t n = outputShape[0];
333 int64_t oh = outputShape[1];
334 int64_t ow = outputShape[2];
335 int64_t c = outputShape[3];
336 int64_t fh = filterTShape[1];
337 int64_t fw = filterTShape[2];
340 Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});
342 AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
346 convOp.getStrides().getValues<
int64_t>()[0]);
348 convOp.getStrides().getValues<
int64_t>()[1]);
351 owDim * swSym + kwDim};
353 auto nloops = colTensorShape.size();
356 nloops, utils::IteratorType::parallel);
362 Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
363 inputType.getElementType());
365 auto img2ColTensor = linalg::GenericOp::create(
366 rewriter, loc, colTensor.
getType(),
367 inputT, colTensor, indexingMaps,
370 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
374 {0, 1}, {2, 3}, {4, 5}};
379 auto reshapedImg2ColTensorType = RankedTensorType::get(
380 {n * c, oh * ow, fh * fw}, inputType.getElementType());
381 auto reshapedFilterTensorType =
382 RankedTensorType::get({c, fh * fw}, filterType.getElementType());
383 auto reshapedOutputTensorType =
384 RankedTensorType::get({n * c, oh * ow}, outputType.getElementType());
386 Value reshapedImg2ColTensor = tensor::CollapseShapeOp::create(
387 rewriter, loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
388 img2ColTensorReassocIndices);
389 Value reshapedFilterTensor =
390 tensor::CollapseShapeOp::create(rewriter, loc, reshapedFilterTensorType,
391 filterT, filterReassociationIndice);
392 Value reshapedoutputTensor = tensor::CollapseShapeOp::create(
393 rewriter, loc, reshapedOutputTensorType, transposedOutputTensor,
394 outputReassociationIndice);
396 auto batchMatVecResult = linalg::BatchMatvecOp::create(
398 ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
404 auto batchMatVecResultReshaped = tensor::ExpandShapeOp::create(
405 rewriter, loc, transposedOutputTensor.
getType(),
406 batchMatVecResult.getResult(0), batchMatVecReassociationIndice);
408 Value transposedResult =
409 transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});
412 return std::make_pair(img2ColTensor.getOperation(),
418 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
419 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
420 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
422 if (!convOp.hasPureTensorSemantics())
424 convOp,
"expected op to have pure tensor semantics");
426 if (!filterType.hasStaticShape())
428 convOp,
"expected a static shape for the filter");
430 if (!inputType.hasStaticShape())
432 "expected a static shape for the input");
437 "expected all ones for dilations");
439 Value input = convOp.getInputs()[0];
440 Value filter = convOp.getInputs()[1];
441 Value output = convOp.getOutputs()[0];
443 auto filterShape = filterType.getShape();
444 auto outputShape = outputType.getShape();
454 auto loc = convOp.getLoc();
457 assert(isa<RankedTensorType>(filterType) &&
458 "expected filter type to be a ranked tensor");
459 auto tensorFilterType = cast<RankedTensorType>(filterType);
// The collapsed filter must keep the filter's own element type: a reshape
// cannot change the element type, so using the input's element type here
// yields a mismatched result type whenever input and filter types differ.
462 auto reshapedFilterType =
463 RankedTensorType::get({oc, ic * fh * fw}, filterType.getElementType(),
464 tensorFilterType.getEncoding());
465 Value reshapedFilter = tensor::CollapseShapeOp::create(
466 rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
469 auto reshapedOutputType =
470 RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
471 Value reshapedOutput = tensor::CollapseShapeOp::create(
472 rewriter, loc, reshapedOutputType, output, outputReassocIndices);
476 Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
477 inputType.getElementType());
479 auto nloops = colTensorShape.size();
481 auto parallel = utils::IteratorType::parallel;
482 auto reduction = utils::IteratorType::reduction;
493 i2cToOperExprs.
icIndex = kIndicesExprs[0];
494 i2cToOperExprs.
fhIndex = kIndicesExprs[1];
495 i2cToOperExprs.
fwIndex = kIndicesExprs[2];
496 i2cToOperExprs.
ohIndex = mIndicesExprs[0];
497 i2cToOperExprs.
owIndex = mIndicesExprs[1];
499 i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<
int64_t>()),
509 auto img2ColTensor = linalg::GenericOp::create(
510 rewriter, loc, colTensor.
getType(),
511 input, colTensor, img2colIndexingMaps,
514 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
522 bindDims(context, bDim, mDim, nDim, kDim);
525 auto resultMap =
AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
527 parallel, reduction};
528 auto genericOp = linalg::GenericOp::create(
529 rewriter, loc, reshapedOutputType,
530 ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
537 linalg::YieldOp::create(nestedBuilder, nestedLoc,
add);
541 auto reshapedResult = tensor::ExpandShapeOp::create(
542 rewriter, loc, outputType,
result, outputReassocIndices);
546 return std::make_pair(img2ColTensor.getOperation(),
547 reshapedResult.getOperation());
552 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
553 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
554 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
556 if (!convOp.hasPureTensorSemantics())
558 convOp,
"expected op to have pure tensor semantics");
560 if (!filterType.hasStaticShape())
562 convOp,
"expected a static shape for the filter");
564 if (!inputType.hasStaticShape())
566 "expected a static shape for the input");
571 "expected all ones for dilations");
574 Value input = convOp.getInputs()[0];
575 Value filter = convOp.getInputs()[1];
576 Value output = convOp.getOutputs()[0];
591 assert(isa<RankedTensorType>(filterType) &&
592 "expected filter type to be a ranked tensor");
593 auto tensorFilterType = cast<RankedTensorType>(filterType);
598 auto reshapedFilterType =
599 RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType(),
600 tensorFilterType.getEncoding());
601 Value reshapedFilter = tensor::CollapseShapeOp::create(
602 rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
605 RankedTensorType reshapedOutputType =
606 RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
607 Value reshapedOutput = tensor::CollapseShapeOp::create(
608 rewriter, loc, reshapedOutputType, output, outputReassocIndices);
612 Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
613 inputType.getElementType());
616 auto nloops = colTensorShape.size();
618 auto parallel = utils::IteratorType::parallel;
619 auto reduction = utils::IteratorType::reduction;
629 i2cToOperExprs.
fhIndex = kIndicesExprs[0];
630 i2cToOperExprs.
fwIndex = kIndicesExprs[1];
631 i2cToOperExprs.
icIndex = kIndicesExprs[2];
632 i2cToOperExprs.
ohIndex = mIndicesExprs[0];
633 i2cToOperExprs.
owIndex = mIndicesExprs[1];
637 i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<
int64_t>()),
646 auto img2ColTensor = linalg::GenericOp::create(
647 rewriter, loc, colTensor.
getType(),
648 input, colTensor, img2colIndexingMaps,
651 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
658 bindDims(context, bDim, mDim, nDim, kDim);
661 auto resultMap =
AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
663 parallel, reduction};
665 auto genericOp = linalg::GenericOp::create(
666 rewriter, loc, reshapedOutputType,
667 ValueRange{img2ColTensor.getResult(0), reshapedFilter},
674 linalg::YieldOp::create(nestedBuilder, nestedLoc,
add);
678 auto reshapedResult = tensor::ExpandShapeOp::create(
679 rewriter, loc, outputType,
result, outputReassocIndices);
683 return std::make_pair(img2ColTensor.getOperation(),
684 reshapedResult.getOperation());
MLIRContext is the top-level object for a collection of MLIR operations.