TPU: gstreamer.py

File gstreamer.py, 9.6 KB (added by Cale Collins, 2 years ago)

gstreamer.py

Line 
1# Copyright 2019 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the 'License');
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an 'AS IS' BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import sys
16import svgwrite
17import threading
18
19import gi
20gi.require_version('Gst', '1.0')
21gi.require_version('GstBase', '1.0')
22gi.require_version('Gtk', '3.0')
23from gi.repository import GLib, GObject, Gst, GstBase, Gtk
24
def _init_gstreamer():
    """Perform the one-time GLib/GStreamer initialisation this module needs.

    Runs at import time so GStreamer is initialised before any pipeline
    description is parsed by GstPipeline.
    """
    # Legacy PyGObject threading setup.
    # NOTE(review): documented as deprecated in recent PyGObject releases;
    # confirm whether the target platform still needs it.
    GObject.threads_init()
    # Initialise GStreamer without forwarding any command-line arguments.
    Gst.init(None)


_init_gstreamer()
27
class GstPipeline:
    """Runs a GStreamer pipeline and feeds appsink frames to a user callback.

    The pipeline description must contain an element named 'appsink' (its
    'new-sample' signal is connected here). It may contain an 'overlay'
    element (its 'data' property receives the SVG) and/or an 'overlaysink'
    element (its 'svg' property receives the SVG). For get_box() it must also
    contain either a 'glbox' bin with a 'filter' child or a 'box' element.

    Frames are handed from the appsink callback to a single inference worker
    thread through ``self.gstbuffer``, guarded by ``self.condition``. Only the
    most recent buffer is kept, so frames arriving while the worker is busy
    overwrite each other and are skipped.
    """

    def __init__(self, pipeline, user_function, src_size):
        """Build the pipeline and hook up signal and bus handlers.

        Args:
            pipeline: gst-launch style pipeline description string.
            user_function: Callable ``(input_tensor, src_size, box) -> svg``
                invoked on the worker thread for each processed frame; a
                falsy return value leaves the overlay unchanged.
            src_size: Passed through unchanged to ``user_function``.

        Raises:
            ValueError: if the pipeline has no element named 'appsink'.
        """
        self.user_function = user_function
        self.running = False
        self.gstbuffer = None  # newest unprocessed buffer, or None
        self.sink_size = None  # (width, height) read from the first appsink sample
        self.src_size = src_size
        self.box = None  # cached result of get_box()
        self.condition = threading.Condition()

        self.pipeline = Gst.parse_launch(pipeline)
        self.overlay = self.pipeline.get_by_name('overlay')
        self.overlaysink = self.pipeline.get_by_name('overlaysink')
        appsink = self.pipeline.get_by_name('appsink')
        if appsink is None:
            raise ValueError("pipeline has no element named 'appsink'")
        appsink.connect('new-sample', self.on_new_sample)

        # Set up a pipeline bus watch to catch errors.
        bus = self.pipeline.get_bus()
        bus.add_signal_watch()
        bus.connect('message', self.on_bus_message)

        # Set up a full screen window on Coral, no-op otherwise.
        self.setup_window()

    def run(self):
        """Start the inference worker, play the pipeline and block in Gtk.main().

        Returns once the GTK main loop exits (EOS, pipeline error, window
        closed, or Ctrl-C). The pipeline is always set back to NULL and the
        worker thread is stopped and joined before returning, even if an
        unexpected exception propagates out of the main loop.
        """
        # Start inference worker.
        self.running = True
        worker = threading.Thread(target=self.inference_loop)
        worker.start()

        try:
            # Run pipeline.
            self.pipeline.set_state(Gst.State.PLAYING)
            try:
                Gtk.main()
            except KeyboardInterrupt:
                # Ctrl-C is the normal way to stop; not an error.
                pass
        finally:
            # Stop the pipeline and drain pending main-context events.
            self.pipeline.set_state(Gst.State.NULL)
            while GLib.MainContext.default().iteration(False):
                pass
            # Wake the worker so it observes running == False and exits;
            # it is a non-daemon thread, so it must be stopped for exit.
            with self.condition:
                self.running = False
                self.condition.notify_all()
            worker.join()

    def on_bus_message(self, bus, message):
        """Bus watch: quit GTK on EOS/ERROR; log WARNING/ERROR to stderr."""
        t = message.type
        if t == Gst.MessageType.EOS:
            Gtk.main_quit()
        elif t == Gst.MessageType.WARNING:
            err, debug = message.parse_warning()
            sys.stderr.write('Warning: %s: %s\n' % (err, debug))
        elif t == Gst.MessageType.ERROR:
            err, debug = message.parse_error()
            sys.stderr.write('Error: %s: %s\n' % (err, debug))
            Gtk.main_quit()
        return True

    def on_new_sample(self, sink):
        """appsink 'new-sample' handler: publish the newest buffer to the worker.

        The appsink size is recorded from the first sample's caps. Any
        buffer the worker has not consumed yet is replaced.
        """
        sample = sink.emit('pull-sample')
        if sample is None:
            # pull-sample yields no sample e.g. when the sink is stopping;
            # nothing to hand over.
            return Gst.FlowReturn.OK
        if not self.sink_size:
            s = sample.get_caps().get_structure(0)
            self.sink_size = (s.get_value('width'), s.get_value('height'))
        with self.condition:
            self.gstbuffer = sample.get_buffer()
            self.condition.notify_all()
        return Gst.FlowReturn.OK

    def get_box(self):
        """Return the (x, y, width, height) box, computing and caching it once.

        On the GL path the box is read from the 'filter' element inside the
        'glbox' bin. Otherwise it is derived from the 'box' element's
        left/top/right/bottom properties and the appsink size.

        Raises:
            RuntimeError: if neither box element exists, or if no sample has
                been received yet so the appsink size is unknown.
        """
        if self.box is None:
            glbox = self.pipeline.get_by_name('glbox')
            if glbox:
                glbox = glbox.get_by_name('filter')
            box = self.pipeline.get_by_name('box')
            if not (glbox or box):
                raise RuntimeError("pipeline has neither a 'glbox' nor a 'box' element")
            if not self.sink_size:
                raise RuntimeError('appsink size is not known yet')
            if glbox:
                self.box = (glbox.get_property('x'), glbox.get_property('y'),
                            glbox.get_property('width'), glbox.get_property('height'))
            else:
                left = box.get_property('left')
                top = box.get_property('top')
                self.box = (-left, -top,
                            self.sink_size[0] + left + box.get_property('right'),
                            self.sink_size[1] + top + box.get_property('bottom'))
        return self.box

    def inference_loop(self):
        """Worker thread body: run user_function on each newest buffer.

        Blocks until a buffer is published or the pipeline stops. Exits as
        soon as ``self.running`` is False, even if a buffer is still pending.
        """
        while True:
            with self.condition:
                while self.gstbuffer is None and self.running:
                    self.condition.wait()
                if not self.running:
                    break
                gstbuffer = self.gstbuffer
                self.gstbuffer = None

            # Passing Gst.Buffer as input tensor avoids 2 copies of it:
            # * Python bindings copies the data when mapping gstbuffer
            # * Numpy copies the data when creating ndarray.
            # This requires a recent version of the python3-edgetpu package. If this
            # raises an exception please make sure dependencies are up to date.
            input_tensor = gstbuffer
            svg = self.user_function(input_tensor, self.src_size, self.get_box())
            if svg:
                if self.overlay:
                    self.overlay.set_property('data', svg)
                if self.overlaysink:
                    self.overlaysink.set_property('svg', svg)

    def setup_window(self):
        """Create a fullscreen GTK window for 'overlaysink'; no-op if absent."""
        # Only set up our own window if we have Coral overlay sink in the pipeline.
        if not self.overlaysink:
            return

        gi.require_version('GstGL', '1.0')
        gi.require_version('GstVideo', '1.0')
        from gi.repository import GstGL, GstVideo

        # Needed to commit the wayland sub-surface.
        def on_gl_draw(sink, widget):
            widget.queue_draw()

        # Needed to account for window chrome etc.
        def on_widget_configure(widget, event, overlaysink):
            allocation = widget.get_allocation()
            overlaysink.set_render_rectangle(allocation.x, allocation.y,
                                             allocation.width, allocation.height)
            return False

        window = Gtk.Window(Gtk.WindowType.TOPLEVEL)
        window.fullscreen()

        drawing_area = Gtk.DrawingArea()
        window.add(drawing_area)
        drawing_area.realize()

        self.overlaysink.connect('drawn', on_gl_draw, drawing_area)

        # Wayland window handle.
        wl_handle = self.overlaysink.get_wayland_window_handle(drawing_area)
        self.overlaysink.set_window_handle(wl_handle)

        # Wayland display context wrapped as a GStreamer context.
        wl_display = self.overlaysink.get_default_wayland_display_context()
        self.overlaysink.set_context(wl_display)

        drawing_area.connect('configure-event', on_widget_configure, self.overlaysink)
        window.connect('delete-event', Gtk.main_quit)
        window.show_all()

        # The appsink pipeline branch must use the same GL display as the screen
        # rendering so they get the same GL context. This isn't automatically handled
        # by GStreamer as we're the ones setting an external display handle.
        def on_bus_message_sync(bus, message, overlaysink):
            if message.type == Gst.MessageType.NEED_CONTEXT:
                _, context_type = message.parse_context_type()
                if context_type == GstGL.GL_DISPLAY_CONTEXT_TYPE:
                    sinkelement = overlaysink.get_by_interface(GstVideo.VideoOverlay)
                    gl_context = sinkelement.get_property('context')
                    if gl_context:
                        display_context = Gst.Context.new(GstGL.GL_DISPLAY_CONTEXT_TYPE, True)
                        GstGL.context_set_gl_display(display_context, gl_context.get_display())
                        message.src.set_context(display_context)
            return Gst.BusSyncReply.PASS

        bus = self.pipeline.get_bus()
        bus.set_sync_handler(on_bus_message_sync, self.overlaysink)
195
def detectCoralDevBoard(model_path='/sys/firmware/devicetree/base/model'):
    """Return True when running on a Coral (i.MX8MQ-based) Edge TPU dev board.

    Reads the device-tree model string and looks for 'MX8MQ'. If the file
    cannot be read (missing on hosts without a device tree, permission
    denied, undecodable bytes), the host is treated as not a dev board.

    Args:
        model_path: Path of the device-tree model file. Overridable for
            testing or for platforms exposing the model elsewhere.

    Returns:
        True if the model string contains 'MX8MQ' (a message is printed),
        otherwise False.
    """
    try:
        # 'with' guarantees the file handle is closed.
        with open(model_path) as model_file:
            model = model_file.read()
    except (OSError, UnicodeDecodeError):
        return False
    if 'MX8MQ' in model:
        print('Detected Edge TPU dev board.')
        return True
    return False
203
def run_pipeline(user_function,
                 src_size,
                 appsink_size,
                 videosrc='/dev/video1',
                 videofmt='raw',
                 host='172.24.24.93',
                 port=9001):
    """Build the camera -> inference/overlay GStreamer pipeline and run it.

    On a Coral dev board (per detectCoralDevBoard) frames are decoded,
    uploaded to GL and shown through glsvgoverlaysink. On other hosts, one
    tee branch is scaled and boxed for the appsink. The other branch gets the
    SVG overlay, is JPEG-encoded and is sent to a TCP server.

    Args:
        user_function: Callback ``(input_tensor, src_size, box) -> svg``;
            see GstPipeline.
        src_size: (width, height) requested from the camera.
        appsink_size: (width, height) of the RGB frames delivered to the appsink.
        videosrc: V4L2 device path.
        videofmt: 'h264', 'jpeg', or anything else for raw video.
        host: Host of the TCP server receiving the overlaid JPEG stream
            (non-Coral path only).
        port: Port of that TCP server (non-Coral path only).

    Blocks until the pipeline stops (see GstPipeline.run).
    """
    if videofmt == 'h264':
        SRC_CAPS = 'video/x-h264,width={width},height={height},framerate=30/1'
    elif videofmt == 'jpeg':
        SRC_CAPS = 'image/jpeg,width={width},height={height},framerate=30/1'
    else:
        SRC_CAPS = 'video/x-raw,width={width},height={height},framerate=30/1'
    # The device path is substituted by format() below rather than being
    # %-interpolated here, so braces in the path cannot be misread as
    # format placeholders.
    PIPELINE = 'v4l2src device={videosrc} ! {src_caps}'

    if detectCoralDevBoard():
        scale_caps = None
        PIPELINE += """ ! decodebin ! glupload ! tee name=t
            t. ! queue ! glfilterbin filter=glbox name=glbox ! {sink_caps} ! {sink_element}
            t. ! queue ! glsvgoverlaysink name=overlaysink
        """
    else:
        # Largest size with the source aspect ratio that fits inside
        # appsink_size; videobox (autocrop=true) then adapts it to the sink caps.
        scale = min(appsink_size[0] / src_size[0], appsink_size[1] / src_size[1])
        scale = tuple(int(x * scale) for x in src_size)
        scale_caps = 'video/x-raw,width={width},height={height}'.format(width=scale[0], height=scale[1])
        PIPELINE += """ ! tee name=t
            t. ! {leaky_q} ! videoconvert ! videoscale ! {scale_caps} ! videobox name=box autocrop=true
               ! {sink_caps} ! {sink_element}
            t. ! {leaky_q} ! videoconvert
               ! rsvgoverlay name=overlay ! videoconvert ! jpegenc ! tcpclientsink host={host} port={port}
            """

    SINK_ELEMENT = 'appsink name=appsink emit-signals=true max-buffers=1 drop=true'
    SINK_CAPS = 'video/x-raw,format=RGB,width={width},height={height}'
    # NOTE(review): despite its name this is a plain queue with default
    # (non-leaky) settings; confirm this is intended.
    LEAKY_Q = 'queue'

    src_caps = SRC_CAPS.format(width=src_size[0], height=src_size[1])
    sink_caps = SINK_CAPS.format(width=appsink_size[0], height=appsink_size[1])
    # Unused keyword arguments (e.g. host/port on the Coral path) are ignored
    # by str.format.
    pipeline = PIPELINE.format(videosrc=videosrc, leaky_q=LEAKY_Q,
                               src_caps=src_caps, sink_caps=sink_caps,
                               sink_element=SINK_ELEMENT, scale_caps=scale_caps,
                               host=host, port=port)

    print('Gstreamer pipeline:\n', pipeline)

    pipeline = GstPipeline(pipeline, user_function, src_size)
    pipeline.run()